Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add schedule/var_squeeze and schedule/var_unsqueeze #593

Merged
merged 2 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ffi/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ void init_ffi_schedule(py::module_ &m) {
"factor"_a = -1, "nparts"_a = -1)
.def("var_merge", &Schedule::varMerge, "vardef"_a, "dim"_a)
.def("var_reorder", &Schedule::varReorder, "vardef"_a, "order"_a)
.def("var_unsqueeze", &Schedule::varUnsqueeze, "vardef"_a, "dim"_a)
.def("var_squeeze", &Schedule::varSqueeze, "vardef"_a, "dim"_a)
.def("move_to", &Schedule::moveTo, "stmt"_a, "side"_a, "dst"_a)
.def("inline", &Schedule::inlining, "vardef"_a)
.def("parallelize", &Schedule::parallelize, "loop"_a, "parallel"_a,
Expand Down
28 changes: 28 additions & 0 deletions include/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,34 @@ class Schedule {
*/
void varReorder(const ID &def, const std::vector<int> &order);

/**
* Insert a singleton (1-lengthed) dimension to a variable
*
* This is a utility schedule, which can be used together with `varSplit`,
* `varMerge` and/or `varReorder` to transform a variable to a desired
* shape.
*
* @param def : ID of the VarDef statement of the specific variable
* @param dim : Insert a singleton dimension at the `dim`-th dimension
* @throw InvalidSchedule if the variable is not found or the dimension is
* illegal
*/
void varUnsqueeze(const ID &def, int dim);

/**
* Remove a singleton (1-lengthed) dimension from a variable
*
* This is a utility schedule, which can be used together with `varSplit`,
* `varMerge` and/or `varReorder` to transform a variable to a desired
* shape.
*
* @param def : ID of the VarDef statement of the specific variable
* @param dim : Remove the `dim`-th dimension
* @throw InvalidSchedule if the variable is not found or the dimension is
* illegal
*/
void varSqueeze(const ID &def, int dim);

/**
* Move a statement to a new position
*
Expand Down
19 changes: 11 additions & 8 deletions include/schedule/schedule_log.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ enum class ScheduleType : int {
VarSplit,
VarMerge,
VarReorder,
VarUnsqueeze,
VarSqueeze,
Inline,
Parallelize,
ParallelizeAs,
Expand All @@ -43,14 +45,15 @@ enum class ScheduleType : int {
};

constexpr std::array scheduleTypeNames = {
"split", "reorder", "merge",
"fission", "fuse", "swap",
"blend", "cache", "cache_reduction",
"set_mem_type", "var_split", "var_merge",
"var_reorder", "inline", "parallelize",
"parallelize_as", "unroll", "vectorize",
"separate_tail", "as_matmul", "permute",
"pluto_fuse", "pluto_permute",
"split", "reorder", "merge",
"fission", "fuse", "swap",
"blend", "cache", "cache_reduction",
"set_mem_type", "var_split", "var_merge",
"var_reorder", "var_unsqueeze", "var_squeeze",
"inline", "parallelize", "parallelize_as",
"unroll", "vectorize", "separate_tail",
"as_matmul", "permute", "pluto_fuse",
"pluto_permute",
};
static_assert(scheduleTypeNames.size() == (size_t)ScheduleType::NumTypes);

Expand Down
12 changes: 12 additions & 0 deletions include/schedule/var_squeeze.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef FREE_TENSOR_VAR_SQUEEZE_H
#define FREE_TENSOR_VAR_SQUEEZE_H

#include <stmt.h>

namespace freetensor {

Stmt varSqueeze(const Stmt &ast, const ID &def, int dim);

}

#endif // FREE_TENSOR_VAR_SQUEEZE_H
12 changes: 12 additions & 0 deletions include/schedule/var_unsqueeze.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef FREE_TENSOR_VAR_UNSQUEEZE_H
#define FREE_TENSOR_VAR_UNSQUEEZE_H

#include <stmt.h>

namespace freetensor {

Stmt varUnsqueeze(const Stmt &ast, const ID &def, int dim);

}

#endif // FREE_TENSOR_VAR_UNSQUEEZE_H
44 changes: 44 additions & 0 deletions python/freetensor/core/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,50 @@ def var_reorder(self, vardef, order):
"""
return super().var_reorder(self._lookup(vardef), order)

def var_unsqueeze(self, vardef, dim):
"""
Insert a singleton (1-lengthed) dimension to a variable

This is a utility schedule, which can be used together with `var_split`,
`var_merge` and/or `var_reorder` to transform a variable to a desired
shape.

Parameters
----------
vardef : str, ID or Stmt
ID of the VarDef statement of the specific variable
dim : int
Insert a singleton dimension at the `dim`-th dimension

Raises
------
InvalidSchedule
if the variable is not found or the dimension is illegal
"""
return super().var_unsqueeze(self._lookup(vardef), dim)

def var_squeeze(self, vardef, dim):
"""
Remove a singleton (1-lengthed) dimension from a variable

This is a utility schedule, which can be used together with `var_split`,
`var_merge` and/or `var_reorder` to transform a variable to a desired
shape.

Parameters
----------
vardef : str, ID or Stmt
ID of the VarDef statement of the specific variable
dim : int
Remove the `dim`-th dimension

Raises
------
InvalidSchedule
if the variable is not found or the dimension is illegal
"""
return super().var_squeeze(self._lookup(vardef), dim)

def move_to(self, stmt, side, dst):
"""
Move a statement to a new position
Expand Down
3 changes: 2 additions & 1 deletion src/schedule/var_merge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ Stmt VarMerge::visit(const VarDef &_op) {
if (_op->id() == def_) {
found_ = true;

if (dim_ + 1 >= (int)_op->buffer_->tensor()->shape().size()) {
if (dim_ < 0 ||
dim_ + 1 >= (int)_op->buffer_->tensor()->shape().size()) {
throw InvalidSchedule(FT_MSG << "There is no dimension " << dim_
<< " ~ " << (dim_ + 1)
<< " in variable " << _op->name_);
Expand Down
2 changes: 1 addition & 1 deletion src/schedule/var_split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Stmt VarSplit::visit(const VarDef &_op) {
if (_op->id() == def_) {
found_ = true;

if (dim_ >= (int)_op->buffer_->tensor()->shape().size()) {
if (dim_ < 0 || dim_ >= (int)_op->buffer_->tensor()->shape().size()) {
throw InvalidSchedule("There is no dimension " +
std::to_string(dim_) + " in variable " +
_op->name_);
Expand Down
79 changes: 79 additions & 0 deletions src/schedule/var_squeeze.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include <hash.h>
#include <mutator.h>
#include <schedule.h>
#include <schedule/var_squeeze.h>

namespace freetensor {

namespace {

class VarSqueeze : public Mutator {
ID defId_;
int dim_;
std::string var_;

public:
VarSqueeze(const ID &def, int dim) : defId_(def), dim_(dim) {}

private:
template <typename T> auto visitAcc(const T &_op) {
auto __op = Mutator::visit(_op);
ASSERT(__op->nodeType() == _op->nodeType());
auto op = __op.template as<typename T::Object>();
if (op->var_ == var_) {
op->indices_.erase(op->indices_.begin() + dim_);
}
return op;
}

protected:
Stmt visit(const VarDef &_op) override {
if (_op->id() == defId_) {
if (dim_ < 0 ||
dim_ >= (int)_op->buffer_->tensor()->shape().size()) {
throw InvalidSchedule("Invalid dimension " +
std::to_string(dim_));
}
var_ = _op->name_;
auto __op = Mutator::visit(_op);
ASSERT(__op->nodeType() == ASTNodeType::VarDef);
auto op = __op.as<VarDefNode>();
var_.clear();
if (!HashComparator{}(op->buffer_->tensor()->shape()[dim_],
makeIntConst(1))) {
throw InvalidSchedule("Dimension " + std::to_string(dim_) +
" is not 1-lengthed");
}
op->buffer_->tensor()->shape().erase(
op->buffer_->tensor()->shape().begin() + dim_);
return op;
} else {
return Mutator::visit(_op);
}
}

Expr visit(const Load &op) override { return visitAcc(op); }
Stmt visit(const Store &op) override { return visitAcc(op); }
Stmt visit(const ReduceTo &op) override { return visitAcc(op); }
};

} // Anonymous namespace

Stmt varSqueeze(const Stmt &ast, const ID &def, int dim) {
return VarSqueeze{def, dim}(ast);
}

void Schedule::varSqueeze(const ID &def, int dim) {
beginTransaction();
auto log = appendLog(
MAKE_SCHEDULE_LOG(VarSqueeze, freetensor::varSqueeze, def, dim));
try {
applyLog(log);
commitTransaction();
} catch (const InvalidSchedule &e) {
abortTransaction();
throw InvalidSchedule(log, ast(), e.what());
}
}

} // namespace freetensor
73 changes: 73 additions & 0 deletions src/schedule/var_unsqueeze.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#include <mutator.h>
#include <schedule.h>
#include <schedule/var_unsqueeze.h>

namespace freetensor {

namespace {

class VarUnsqueeze : public Mutator {
ID defId_;
int dim_;
std::string var_;

public:
VarUnsqueeze(const ID &def, int dim) : defId_(def), dim_(dim) {}

private:
template <typename T> auto visitAcc(const T &_op) {
auto __op = Mutator::visit(_op);
ASSERT(__op->nodeType() == _op->nodeType());
auto op = __op.template as<typename T::Object>();
if (op->var_ == var_) {
op->indices_.insert(op->indices_.begin() + dim_, makeIntConst(0));
}
return op;
}

protected:
Stmt visit(const VarDef &_op) override {
if (_op->id() == defId_) {
if (dim_ < 0 ||
dim_ > (int)_op->buffer_->tensor()->shape().size()) {
throw InvalidSchedule("Invalid dimension " +
std::to_string(dim_));
}
var_ = _op->name_;
auto __op = Mutator::visit(_op);
ASSERT(__op->nodeType() == ASTNodeType::VarDef);
auto op = __op.as<VarDefNode>();
var_.clear();
op->buffer_->tensor()->shape().insert(
op->buffer_->tensor()->shape().begin() + dim_, makeIntConst(1));
return op;
} else {
return Mutator::visit(_op);
}
}

Expr visit(const Load &op) override { return visitAcc(op); }
Stmt visit(const Store &op) override { return visitAcc(op); }
Stmt visit(const ReduceTo &op) override { return visitAcc(op); }
};

} // Anonymous namespace

Stmt varUnsqueeze(const Stmt &ast, const ID &def, int dim) {
return VarUnsqueeze{def, dim}(ast);
}

void Schedule::varUnsqueeze(const ID &def, int dim) {
beginTransaction();
auto log = appendLog(
MAKE_SCHEDULE_LOG(VarUnsqueeze, freetensor::varUnsqueeze, def, dim));
try {
applyLog(log);
commitTransaction();
} catch (const InvalidSchedule &e) {
abortTransaction();
throw InvalidSchedule(log, ast(), e.what());
}
}

} // namespace freetensor
58 changes: 58 additions & 0 deletions test/30.schedule/test_var_squeeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import freetensor as ft
import pytest


def test_basic():
with ft.VarDef([("x", (4, 8), "int32", "input", "cpu"),
("y", (4, 8), "int32", "output", "cpu")]) as (x, y):
ft.MarkLabel("Dc")
with ft.VarDef("c", (4, 1, 8), "int32", "cache", "cpu") as c:
with ft.For("i", 0, 4) as i:
with ft.For("j", 0, 8) as j:
c[i, 0, j] = x[i, j] * 2
with ft.For("i", 0, 4) as i:
with ft.For("j", 0, 8) as j:
y[i, j] = c[i, 0, j] + 1
ast = ft.pop_ast(verbose=True)
s = ft.Schedule(ast)
s.var_squeeze("Dc", 1)
ast = s.ast()
print(ast)
ast = ft.lower(ast, skip_passes=['prop_one_time_use'], verbose=1)

with ft.VarDef([("x", (4, 8), "int32", "input", "cpu"),
("y", (4, 8), "int32", "output", "cpu")]) as (x, y):
ft.MarkLabel("Dc")
with ft.VarDef("c", (4, 8), "int32", "cache", "cpu") as c:
with ft.For("i", 0, 4) as i:
with ft.For("j", 0, 8) as j:
c[i, j] = x[i, j] * 2
with ft.For("i", 0, 4) as i:
with ft.For("j", 0, 8) as j:
y[i, j] = c[i, j] + 1
std = ft.pop_ast()


def test_not_singleton():
ft.MarkLabel("Dy")
with ft.VarDef("y", (8,), "int32", "output", "cpu") as y:
with ft.For("i", 0, 8) as i:
y[i] = i
ast = ft.pop_ast(verbose=True)
s = ft.Schedule(ast)
with pytest.raises(ft.InvalidSchedule):
s.var_squeeze("Dy", 0)
ast_ = s.ast() # Should not changed
assert ast_.match(ast)


def test_out_of_range():
ft.MarkLabel("Dy")
with ft.VarDef("y", (1,), "int32", "output", "cpu") as y:
y[0] = 1
ast = ft.pop_ast(verbose=True)
s = ft.Schedule(ast)
with pytest.raises(ft.InvalidSchedule):
s.var_squeeze("Dy", 1)
ast_ = s.ast() # Should not changed
assert ast_.match(ast)
Loading
Loading