From 48b3d3c2ff9f7b4ee1e44bd28bb5b097bfcd9eac Mon Sep 17 00:00:00 2001 From: Shizhi Tang Date: Sat, 20 Jan 2024 20:13:34 +0800 Subject: [PATCH 1/2] Add schedule/var_squeeze and schedule/var_unsqueeze --- ffi/schedule.cc | 2 + include/schedule.h | 28 +++++++++ include/schedule/schedule_log.h | 19 ++++--- include/schedule/var_squeeze.h | 12 ++++ include/schedule/var_unsqueeze.h | 12 ++++ python/freetensor/core/schedule.py | 44 ++++++++++++++ src/schedule/var_merge.cc | 3 +- src/schedule/var_split.cc | 2 +- src/schedule/var_squeeze.cc | 79 ++++++++++++++++++++++++++ src/schedule/var_unsqueeze.cc | 73 ++++++++++++++++++++++++ test/30.schedule/test_var_squeeze.py | 58 +++++++++++++++++++ test/30.schedule/test_var_unsqueeze.py | 46 +++++++++++++++ 12 files changed, 368 insertions(+), 10 deletions(-) create mode 100644 include/schedule/var_squeeze.h create mode 100644 include/schedule/var_unsqueeze.h create mode 100644 src/schedule/var_squeeze.cc create mode 100644 src/schedule/var_unsqueeze.cc create mode 100644 test/30.schedule/test_var_squeeze.py create mode 100644 test/30.schedule/test_var_unsqueeze.py diff --git a/ffi/schedule.cc b/ffi/schedule.cc index 224a0a3ce..7b965c780 100644 --- a/ffi/schedule.cc +++ b/ffi/schedule.cc @@ -127,6 +127,8 @@ void init_ffi_schedule(py::module_ &m) { "factor"_a = -1, "nparts"_a = -1) .def("var_merge", &Schedule::varMerge, "vardef"_a, "dim"_a) .def("var_reorder", &Schedule::varReorder, "vardef"_a, "order"_a) + .def("var_unsqueeze", &Schedule::varUnsqueeze, "vardef"_a, "dim"_a) + .def("var_squeeze", &Schedule::varSqueeze, "vardef"_a, "dim"_a) .def("move_to", &Schedule::moveTo, "stmt"_a, "side"_a, "dst"_a) .def("inline", &Schedule::inlining, "vardef"_a) .def("parallelize", &Schedule::parallelize, "loop"_a, "parallel"_a, diff --git a/include/schedule.h b/include/schedule.h index c0c421260..d3759b1fb 100644 --- a/include/schedule.h +++ b/include/schedule.h @@ -554,6 +554,34 @@ class Schedule { */ void varReorder(const ID &def, const std::vector &order); + /** + * Insert a singleton (1-lengthed) dimension to a variable + * + * This is a utility schedule, which can be used together with `varSplit`, + * `varMerge` and/or `varReorder` to transform a variable to a desired + * shape. + * + * @param def : ID of the VarDef statement of the specific variable + * @param dim : Insert a singleton dimension at the `dim`-th dimension + * @throw InvalidSchedule if the variable is not found or the dimension is + * illegal + */ + void varUnsqueeze(const ID &def, int dim); + + /** + * Remove a singleton (1-lengthed) dimension from a variable + * + * This is a utility schedule, which can be used together with `varSplit`, + * `varMerge` and/or `varReorder` to transform a variable to a desired + * shape. + * + * @param def : ID of the VarDef statement of the specific variable + * @param dim : Remove the `dim`-th dimension + * @throw InvalidSchedule if the variable is not found or the dimension is + * illegal + */ + void varSqueeze(const ID &def, int dim); + /** * Move a statement to a new position * diff --git a/include/schedule/schedule_log.h b/include/schedule/schedule_log.h index 902e6f2dc..08287c1bc 100644 --- a/include/schedule/schedule_log.h +++ b/include/schedule/schedule_log.h @@ -28,6 +28,8 @@ enum class ScheduleType : int { VarSplit, VarMerge, VarReorder, + VarUnsqueeze, + VarSqueeze, Inline, Parallelize, ParallelizeAs, @@ -43,14 +45,15 @@ enum class ScheduleType : int { }; constexpr std::array scheduleTypeNames = { - "split", "reorder", "merge", - "fission", "fuse", "swap", - "blend", "cache", "cache_reduction", - "set_mem_type", "var_split", "var_merge", - "var_reorder", "inline", "parallelize", - "parallelize_as", "unroll", "vectorize", - "separate_tail", "as_matmul", "permute", - "pluto_fuse", "pluto_permute", + "split", "reorder", "merge", + "fission", "fuse", "swap", + "blend", "cache", "cache_reduction", + "set_mem_type", "var_split", "var_merge", + "var_reorder", "var_unsqueeze", "var_squeeze", + "inline", "parallelize", "parallelize_as", + "unroll", "vectorize", "separate_tail", + "as_matmul", "permute", "pluto_fuse", + "pluto_permute", }; static_assert(scheduleTypeNames.size() == (size_t)ScheduleType::NumTypes); diff --git a/include/schedule/var_squeeze.h b/include/schedule/var_squeeze.h new file mode 100644 index 000000000..67d973e17 --- /dev/null +++ b/include/schedule/var_squeeze.h @@ -0,0 +1,12 @@ +#ifndef FREE_TENSOR_VAR_SQUEEZE_H +#define FREE_TENSOR_VAR_SQUEEZE_H + +#include + +namespace freetensor { + +Stmt varSqueeze(const Stmt &ast, const ID &def, int dim); + +} + +#endif // FREE_TENSOR_VAR_SQUEEZE_H diff --git a/include/schedule/var_unsqueeze.h b/include/schedule/var_unsqueeze.h new file mode 100644 index 000000000..a0508a7d6 --- /dev/null +++ b/include/schedule/var_unsqueeze.h @@ -0,0 +1,12 @@ +#ifndef FREE_TENSOR_VAR_UNSQUEEZE_H +#define FREE_TENSOR_VAR_UNSQUEEZE_H + +#include + +namespace freetensor { + +Stmt varUnsqueeze(const Stmt &ast, const ID &def, int dim); + +} + +#endif // FREE_TENSOR_VAR_UNSQUEEZE_H diff --git a/python/freetensor/core/schedule.py b/python/freetensor/core/schedule.py index f2a8a85d0..8fc3d4160 100644 --- a/python/freetensor/core/schedule.py +++ b/python/freetensor/core/schedule.py @@ -565,6 +565,50 @@ def var_reorder(self, vardef, order): """ return super().var_reorder(self._lookup(vardef), order) + def var_unsqueeze(self, vardef, dim): + """ + Insert a singleton (1-lengthed) dimension to a variable + + This is a utility schedule, which can be used together with `varSplit`, + `varMerge` and/or `varReorder` to transform a variable to a desired + shape. + + Parameters + ---------- + vardef : str, ID or Stmt + ID of the VarDef statement of the specific variable + dim : int + Insert a singleton dimension at the `dim`-th dimension + + Raises + ------ + InvalidSchedule + if the variable is not found or the dimension is illegal + """ + return super().var_unsqueeze(self._lookup(vardef), dim) + + def var_squeeze(self, vardef, dim): + """ + Remove a singleton (1-lengthed) dimension from a variable + + This is a utility schedule, which can be used together with `varSplit`, + `varMerge` and/or `varReorder` to transform a variable to a desired + shape. + + Parameters + ---------- + vardef : str, ID or Stmt + ID of the VarDef statement of the specific variable + dim : int + Remove the `dim`-th dimension + + Raises + ------ + InvalidSchedule + if the variable is not found or the dimension is illegal + """ + return super().var_squeeze(self._lookup(vardef), dim) + def move_to(self, stmt, side, dst): """ Move a statement to a new position diff --git a/src/schedule/var_merge.cc b/src/schedule/var_merge.cc index 59590b81e..42de2f006 100644 --- a/src/schedule/var_merge.cc +++ b/src/schedule/var_merge.cc @@ -7,7 +7,8 @@ Stmt VarMerge::visit(const VarDef &_op) { if (_op->id() == def_) { found_ = true; - if (dim_ + 1 >= (int)_op->buffer_->tensor()->shape().size()) { + if (dim_ < 0 || + dim_ + 1 >= (int)_op->buffer_->tensor()->shape().size()) { throw InvalidSchedule(FT_MSG << "There is no dimension " << dim_ << " ~ " << (dim_ + 1) << " in variable " << _op->name_); diff --git a/src/schedule/var_split.cc b/src/schedule/var_split.cc index 7d785235b..922f3bc7d 100644 --- a/src/schedule/var_split.cc +++ b/src/schedule/var_split.cc @@ -7,7 +7,7 @@ Stmt VarSplit::visit(const VarDef &_op) { if (_op->id() == def_) { found_ = true; - if (dim_ >= (int)_op->buffer_->tensor()->shape().size()) { + if (dim_ < 0 || dim_ >= (int)_op->buffer_->tensor()->shape().size()) { throw InvalidSchedule("There is no dimension " + std::to_string(dim_) + " in variable " + _op->name_); diff --git a/src/schedule/var_squeeze.cc b/src/schedule/var_squeeze.cc new file mode 100644 index 000000000..e63bec6db --- /dev/null +++ b/src/schedule/var_squeeze.cc @@ -0,0 +1,79 @@ +#include +#include +#include +#include + +namespace freetensor { + +namespace { + +class VarSqueeze : public Mutator { + ID defId_; + int dim_; + std::string var_; + + public: + VarSqueeze(const ID &def, int dim) : defId_(def), dim_(dim) {} + + private: + template auto visitAcc(const T &_op) { + auto __op = Mutator::visit(_op); + ASSERT(__op->nodeType() == _op->nodeType()); + auto op = __op.template as(); + if (op->var_ == var_) { + op->indices_.erase(op->indices_.begin() + dim_); + } + return op; + } + + protected: + Stmt visit(const VarDef &_op) override { + if (_op->id() == defId_) { + if (dim_ < 0 || + dim_ >= (int)_op->buffer_->tensor()->shape().size()) { + throw InvalidSchedule("Invalid dimension " + + std::to_string(dim_)); + } + var_ = _op->name_; + auto __op = Mutator::visit(_op); + ASSERT(__op->nodeType() == ASTNodeType::VarDef); + auto op = __op.as(); + var_.clear(); + if (!HashComparator{}(op->buffer_->tensor()->shape()[dim_], + makeIntConst(1))) { + throw InvalidSchedule("Dimension " + std::to_string(dim_) + + " is not 1-lengthed"); + } + op->buffer_->tensor()->shape().erase( + op->buffer_->tensor()->shape().begin() + dim_); + return op; + } else { + return Mutator::visit(_op); + } + } + + Expr visit(const Load &op) override { return visitAcc(op); } + Stmt visit(const Store &op) override { return visitAcc(op); } + Stmt visit(const ReduceTo &op) override { return visitAcc(op); } +}; + +} // Anonymous namespace + +Stmt varSqueeze(const Stmt &ast, const ID &def, int dim) { + return VarSqueeze{def, dim}(ast); +} + +void Schedule::varSqueeze(const ID &def, int dim) { + beginTransaction(); + auto log = appendLog( + MAKE_SCHEDULE_LOG(VarSqueeze, freetensor::varSqueeze, def, dim)); + try { + applyLog(log); + commitTransaction(); + } catch (const InvalidSchedule &e) { + abortTransaction(); + throw InvalidSchedule(log, ast(), e.what()); + } +} + +} // namespace freetensor diff --git a/src/schedule/var_unsqueeze.cc b/src/schedule/var_unsqueeze.cc new file mode 100644 index 000000000..d8f775c0e --- /dev/null +++ b/src/schedule/var_unsqueeze.cc @@ -0,0 +1,73 @@ +#include +#include +#include + +namespace freetensor { + +namespace { + +class VarUnsqueeze : public Mutator { + ID defId_; + int dim_; + std::string var_; + + public: + VarUnsqueeze(const ID &def, int dim) : defId_(def), dim_(dim) {} + + private: + template auto visitAcc(const T &_op) { + auto __op = Mutator::visit(_op); + ASSERT(__op->nodeType() == _op->nodeType()); + auto op = __op.template as(); + if (op->var_ == var_) { + op->indices_.insert(op->indices_.begin() + dim_, makeIntConst(0)); + } + return op; + } + + protected: + Stmt visit(const VarDef &_op) override { + if (_op->id() == defId_) { + if (dim_ < 0 || + dim_ > (int)_op->buffer_->tensor()->shape().size()) { + throw InvalidSchedule("Invalid dimension " + + std::to_string(dim_)); + } + var_ = _op->name_; + auto __op = Mutator::visit(_op); + ASSERT(__op->nodeType() == ASTNodeType::VarDef); + auto op = __op.as(); + var_.clear(); + op->buffer_->tensor()->shape().insert( + op->buffer_->tensor()->shape().begin() + dim_, makeIntConst(1)); + return op; + } else { + return Mutator::visit(_op); + } + } + + Expr visit(const Load &op) override { return visitAcc(op); } + Stmt visit(const Store &op) override { return visitAcc(op); } + Stmt visit(const ReduceTo &op) override { return visitAcc(op); } +}; + +} // Anonymous namespace + +Stmt varUnsqueeze(const Stmt &ast, const ID &def, int dim) { + return VarUnsqueeze{def, dim}(ast); +} + +void Schedule::varUnsqueeze(const ID &def, int dim) { + beginTransaction(); + auto log = appendLog( + MAKE_SCHEDULE_LOG(VarUnsqueeze, freetensor::varUnsqueeze, def, dim)); + try { + applyLog(log); + commitTransaction(); + } catch (const InvalidSchedule &e) { + abortTransaction(); + throw InvalidSchedule(log, ast(), e.what()); + } +} + +} // namespace freetensor diff --git a/test/30.schedule/test_var_squeeze.py b/test/30.schedule/test_var_squeeze.py new file mode 100644 index 000000000..ac1ae7537 --- /dev/null +++ b/test/30.schedule/test_var_squeeze.py @@ -0,0 +1,58 @@ +import freetensor as ft +import pytest + + +def test_basic(): + with ft.VarDef([("x", (4, 8), "int32", "input", "cpu"), + ("y", (4, 8), "int32", "output", "cpu")]) as (x, y): + ft.MarkLabel("Dc") + with ft.VarDef("c", (4, 1, 8), "int32", "cache", "cpu") as c: + with ft.For("i", 0, 4) as i: + with ft.For("j", 0, 8) as j: + c[i, 0, j] = x[i, j] * 2 + with ft.For("i", 0, 4) as i: + with ft.For("j", 0, 8) as j: + y[i, j] = c[i, 0, j] + 1 + ast = ft.pop_ast(verbose=True) + s = ft.Schedule(ast) + s.var_squeeze("Dc", 1) + ast = s.ast() + print(ast) + ast = ft.lower(ast, skip_passes=['prop_one_time_use'], verbose=1) + + with ft.VarDef([("x", (4, 8), "int32", "input", "cpu"), + ("y", (4, 8), "int32", "output", "cpu")]) as (x, y): + ft.MarkLabel("Dc") + with ft.VarDef("c", (4, 8), "int32", "cache", "cpu") as c: + with ft.For("i", 0, 4) as i: + with ft.For("j", 0, 8) as j: + c[i, j] = x[i, j] * 2 + with ft.For("i", 0, 4) as i: + with ft.For("j", 0, 8) as j: + y[i, j] = c[i, j] + 1 + std = ft.pop_ast() + + +def test_not_singleton(): + ft.MarkLabel("Dy") + with ft.VarDef("y", (8,), "int32", "output", "cpu") as y: + with ft.For("i", 0, 8) as i: + y[i] = i + ast = ft.pop_ast(verbose=True) + s = ft.Schedule(ast) + with pytest.raises(ft.InvalidSchedule): + s.var_squeeze("Dy", 0) + ast_ = s.ast() # Should not changed + assert ast_.match(ast) + + +def test_out_of_range(): + ft.MarkLabel("Dy") + with ft.VarDef("y", (1,), "int32", "output", "cpu") as y: + y[0] = 1 + ast = ft.pop_ast(verbose=True) + s = ft.Schedule(ast) + with pytest.raises(ft.InvalidSchedule): + s.var_squeeze("Dy", 1) + ast_ = s.ast() # Should not changed + assert ast_.match(ast) diff --git a/test/30.schedule/test_var_unsqueeze.py b/test/30.schedule/test_var_unsqueeze.py new file mode 100644 index 000000000..72971013c --- /dev/null +++ b/test/30.schedule/test_var_unsqueeze.py @@ -0,0 +1,46 @@ +import freetensor as ft +import pytest + + +def test_basic(): + with ft.VarDef([("x", (4, 8), "int32", "input", "cpu"), + ("y", (4, 8), "int32", "output", "cpu")]) as (x, y): + ft.MarkLabel("Dc") + with ft.VarDef("c", (4, 8), "int32", "cache", "cpu") as c: + with ft.For("i", 0, 4) as i: + with ft.For("j", 0, 8) as j: + c[i, j] = x[i, j] * 2 + with ft.For("i", 0, 4) as i: + with ft.For("j", 0, 8) as j: + y[i, j] = c[i, j] + 1 + ast = ft.pop_ast(verbose=True) + s = ft.Schedule(ast) + s.var_unsqueeze("Dc", 1) + ast = s.ast() + print(ast) + ast = ft.lower(ast, skip_passes=['prop_one_time_use'], verbose=1) + + with ft.VarDef([("x", (4, 8), "int32", "input", "cpu"), + ("y", (4, 8), "int32", "output", "cpu")]) as (x, y): + ft.MarkLabel("Dc") + with ft.VarDef("c", (4, 1, 8), "int32", "cache", "cpu") as c: + with ft.For("i", 0, 4) as i: + with ft.For("j", 0, 8) as j: + c[i, 0, j] = x[i, j] * 2 + with ft.For("i", 0, 4) as i: + with ft.For("j", 0, 8) as j: + y[i, j] = c[i, 0, j] + 1 + std = ft.pop_ast() + + +def test_out_of_range(): + ft.MarkLabel("Dy") + with ft.VarDef("y", (8,), "int32", "output", "cpu") as y: + with ft.For("i", 0, 8) as i: + y[i] = i + ast = ft.pop_ast(verbose=True) + s = ft.Schedule(ast) + with pytest.raises(ft.InvalidSchedule): + s.var_unsqueeze("Dy", 2) + ast_ = s.ast() # Should not changed + assert ast_.match(ast) From 9c00e49ab3ebfeb069ab5caecb2b041e99ec5269 Mon Sep 17 00:00:00 2001 From: Shizhi Tang Date: Sat, 20 Jan 2024 20:16:31 +0800 Subject: [PATCH 2/2] Minor documentation change --- python/freetensor/core/schedule.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/freetensor/core/schedule.py b/python/freetensor/core/schedule.py index 8fc3d4160..f999835a0 100644 --- a/python/freetensor/core/schedule.py +++ b/python/freetensor/core/schedule.py @@ -569,8 +569,8 @@ def var_unsqueeze(self, vardef, dim): """ Insert a singleton (1-lengthed) dimension to a variable - This is a utility schedule, which can be used together with `varSplit`, - `varMerge` and/or `varReorder` to transform a variable to a desired + This is a utility schedule, which can be used together with `var_split`, + `var_merge` and/or `var_reorder` to transform a variable to a desired shape. Parameters @@ -591,8 +591,8 @@ def var_squeeze(self, vardef, dim): """ Remove a singleton (1-lengthed) dimension from a variable - This is a utility schedule, which can be used together with `varSplit`, - `varMerge` and/or `varReorder` to transform a variable to a desired + This is a utility schedule, which can be used together with `var_split`, + `var_merge` and/or `var_reorder` to transform a variable to a desired shape. Parameters