diff --git a/ffi/array.cc b/ffi/array.cc
index 36b2a262d..bbb0eb8aa 100644
--- a/ffi/array.cc
+++ b/ffi/array.cc
@@ -52,6 +52,8 @@ static DataType dtypeFromPyTorch(torch::ScalarType t) {
         return DataType::Int32;
     case torch::ScalarType::Long:
         return DataType::Int64;
+    case torch::ScalarType::Half:
+        return DataType::Float16;
     case torch::ScalarType::Float:
         return DataType::Float32;
     case torch::ScalarType::Double:
@@ -69,6 +71,8 @@ static torch::ScalarType dtypeToPyTorch(DataType dtype) {
         return torch::ScalarType::Int;
     case DataType::Int64:
         return torch::ScalarType::Long;
+    case DataType::Float16:
+        return torch::ScalarType::Half;
     case DataType::Float32:
         return torch::ScalarType::Float;
     case DataType::Float64:
@@ -116,9 +120,10 @@ void init_ffi_array(py::module_ &m) {
                  // or it will all end up in float64 (the first initializer)
                  throw DriverError(
                      "Unsupported data type or strides from a NumPy Array. "
-                     "Please "
-                     "use freetensor.array factory function, instead of "
-                     "freetensor.Array, for strided arrays");
+                     "If you are using strided arrays, please use "
+                     "freetensor.array factory function, instead of "
+                     "freetensor.Array. If you are using float16, please use "
+                     "the PyTorch interface instead.");
              }),
              "data"_a, "dont_drop_borrow"_a = false, "moved"_a = false)
         .def("__eq__", [](const Ref &lhs, const Ref &rhs) {
@@ -172,6 +177,14 @@ void init_ffi_array(py::module_ &m) {
             SHARE_TO_NUMPY(int64_t, DataType::Int64)
             SHARE_TO_NUMPY(int32_t, DataType::Int32)
             SHARE_TO_NUMPY(bool, DataType::Bool)
+        case DataType::Float16:
+            // TODO: Support fp16 after PyBind11 for NumPy fp16 is
+            // available. Status:
+            // https://github.com/pybind/pybind11/issues/1776,
+            // https://github.com/pybind/pybind11/issues/4061
+            throw DriverError(
+                "NumPy interface for float16 is not supported yet. Please "
Please " + "use the PyTorch interface instead."); default: ASSERT(false); } diff --git a/include/codegen/code_gen_c.h b/include/codegen/code_gen_c.h index e270c7236..b501b8e35 100644 --- a/include/codegen/code_gen_c.h +++ b/include/codegen/code_gen_c.h @@ -22,7 +22,7 @@ template class CodeGenC : public CodeGen { const std::vector &returns) : params_(params), returns_(returns) {} - static std::string gen(DataType dtype); + virtual std::string gen(const DataType &dtype); protected: virtual void genAlloc(const Ref &tensor, const std::string &rawPtr, diff --git a/include/codegen/code_gen_cuda.h b/include/codegen/code_gen_cuda.h index b915dc129..a8d006bcb 100644 --- a/include/codegen/code_gen_cuda.h +++ b/include/codegen/code_gen_cuda.h @@ -47,6 +47,8 @@ class CodeGenCUDA : public CodeGenC { Expr globalSize() const { return globalSize_; } + std::string gen(const DataType &dtype) override; + private: bool inKernel() const; @@ -87,6 +89,7 @@ class CodeGenCUDA : public CodeGenC { void visit(const Abs &op) override; void visit(const Floor &op) override; void visit(const Ceil &op) override; + void visit(const Cast &op) override; void visit(const ReduceTo &op) override; void visit(const Var &op) override; void visit(const For &op) override; diff --git a/include/pass/const_fold.h b/include/pass/const_fold.h index c596f7ec3..26e9fe7b2 100644 --- a/include/pass/const_fold.h +++ b/include/pass/const_fold.h @@ -118,6 +118,7 @@ class ConstFold : public Mutator { case DataType::Int32: case DataType::Int64: return wrap(int64_t(v)); + case DataType::Float16: case DataType::Float32: case DataType::Float64: return wrap(double(v)); diff --git a/include/type/data_type.h b/include/type/data_type.h index 8ec191e88..fdccb8304 100644 --- a/include/type/data_type.h +++ b/include/type/data_type.h @@ -12,6 +12,7 @@ namespace freetensor { enum class BaseDataType : size_t { Void = 0, // Returns nothing. It is a Unit Type + Float16, Float32, Float64, Int32, @@ -24,7 +25,8 @@ enum class BaseDataType : size_t { }; constexpr std::array baseDataTypeNames = { - "void", "float32", "float64", "int32", "int64", "bool", "custom", "never", + "void", "float16", "float32", "float64", "int32", + "int64", "bool", "custom", "never", }; static_assert(baseDataTypeNames.size() == (size_t)BaseDataType::NumTypes); @@ -105,6 +107,7 @@ class DataType { // and remove the following lines constexpr static auto Bool = BaseDataType::Bool; constexpr static auto Custom = BaseDataType::Custom; + constexpr static auto Float16 = BaseDataType::Float16; constexpr static auto Float32 = BaseDataType::Float32; constexpr static auto Float64 = BaseDataType::Float64; constexpr static auto Int32 = BaseDataType::Int32; diff --git a/python/freetensor/core/staging.py b/python/freetensor/core/staging.py index 04cc50829..09abe49c8 100644 --- a/python/freetensor/core/staging.py +++ b/python/freetensor/core/staging.py @@ -82,8 +82,8 @@ def process_annotating_comments(src: str): for line in src.splitlines(): indent = re.match('\\s*', line)[0] rest_line = line[len(indent):] - if rest_line.startswith('#! 
-            arg = rest_line[3:].replace('"', '\\"')
+        if rest_line.startswith('#!'):
+            arg = rest_line[2:].strip().replace('"', '\\"')
             new_src.append(f'{indent}__staging_overload__.metadata("{arg}")')
         else:
             new_src.append(line)
diff --git a/runtime/gpu_runtime.h b/runtime/gpu_runtime.h
index d311130d5..5d8c2d99f 100644
--- a/runtime/gpu_runtime.h
+++ b/runtime/gpu_runtime.h
@@ -9,8 +9,11 @@
 #include 
 #include 
 
-#include "cutlass/cutlass.h"
-#include "cutlass/gemm/device/gemm.h"
+#include 
+
+#include 
+#include 
+#include 
 
 #include "gpu_context.h"
diff --git a/src/codegen/code_gen_cuda.cc b/src/codegen/code_gen_cuda.cc
index aaf9d5f5f..48c8e97f9 100644
--- a/src/codegen/code_gen_cuda.cc
+++ b/src/codegen/code_gen_cuda.cc
@@ -15,12 +15,14 @@
 
 namespace freetensor {
 
-static std::string genCUBLASType(DataType dtype) {
+static std::string genCUBLASType(const DataType &dtype) {
     switch (dtype.base()) {
     case DataType::Float64:
         return "CUDA_R_64F";
     case DataType::Float32:
         return "CUDA_R_32F";
+    case DataType::Float16:
+        return "CUDA_R_16F";
     case DataType::Int64:
         return "CUDA_R_64I";
     case DataType::Int32:
@@ -30,9 +32,32 @@ static std::string genCUBLASType(DataType dtype) {
     }
 }
 
+static std::string genCUTLASSType(const DataType &dtype) {
+    switch (dtype.base()) {
+    case DataType::Float64:
+        return "double";
+    case DataType::Float32:
+        return "float";
+    case DataType::Float16:
+        return "cutlass::half_t";
+    case DataType::Int64:
+        return "int64_t";
+    case DataType::Int32:
+        return "int32_t";
+    case DataType::Bool:
+        return "bool";
+    default:
+        ASSERT(false);
+    }
+}
+
 static bool canUseTensorCore(const Ref &target, DataType dtypeA,
                              DataType dtypeB, DataType dtypeC) {
-    // TODO: fp16 is supported after sm70
+    if (target->computeCapability().first >= 7 && dtypeA == DataType::Float16 &&
+        dtypeB == DataType::Float16 &&
+        (dtypeC == DataType::Float16 || dtypeC == DataType::Float32)) {
+        return true;
+    }
     if (target->computeCapability().first >= 8 && dtypeA == DataType::Float64 &&
         dtypeB == DataType::Float64 && dtypeC == DataType::Float64) {
         return true;
@@ -49,7 +74,17 @@ CodeGenCUDA::genMdPtrType(const VarDef &def, bool isConst) {
     // Use pointer instead of reference for scalars, because when passing an
     // argument from host to a kernel, a reference means copy the value from
     // CPU to GPU, while a pointer means passing the address
+
+    // NOTE: `[=]` implicitly capturing `this` is deprecated in C++20. Using
+    // `[=]` will trigger a warning in GCC (because of deprecation), but
+    // using
+    // `[=, this]` will trigger a warning in Clang<17 (because it will think
+    // `this` is duplicated).
+#if defined(__clang__) && __clang_major__ < 17
     return [=](std::ostream &os) -> std::ostream & {
+#else
+    return [=, this](std::ostream &os) -> std::ostream & {
+#endif
         if (isConst) {
             os << "const ";
         }
@@ -77,6 +112,14 @@ void CodeGenCUDA::genMdPtrDef(const VarDef &def,
     CodeGenC::genMdPtrDef(def, genRawPtr, isConst);
 }
 
+std::string CodeGenCUDA::gen(const DataType &dtype) {
+    if (dtype == DataType::Float16) {
+        return "__half";
+    } else {
+        return CodeGenC::gen(dtype);
+    }
+}
+
 void CodeGenCUDA::genAlloc(const Ref &tensor, const std::string &rawPtr,
                            const std::string &shapePtr,
                            const std::string &dimPtr) {
@@ -383,6 +426,72 @@ void CodeGenCUDA::visit(const Ceil &op) {
     os() << ")";
 }
 
+void CodeGenCUDA::visit(const Cast &op) {
+    if (op->destType_.base() == DataType::Float16) {
+        switch (op->expr_->dtype().base()) {
+        case DataType::Int32:
+            os() << "__int2half_rn(";
+            (*this)(op->expr_);
+            os() << ")";
+            break;
+        case DataType::Int64:
+            os() << "__ll2half_rn(";
+            (*this)(op->expr_);
+            os() << ")";
+            break;
+        case DataType::Float16:
+            (*this)(op->expr_);
+            break;
+        case DataType::Float32:
+            os() << "__float2half_rn(";
+            (*this)(op->expr_);
+            os() << ")";
+            break;
+        case DataType::Float64:
+            os() << "__double2half("; // Always `_rn` (round to nearest even)
+            (*this)(op->expr_);
+            os() << ")";
+            break;
+        default:
+            throw InvalidProgram("Converting from " +
+                                 freetensor::toString(op->expr_->dtype()) +
+                                 " to float16 is not supported");
+        }
+    } else if (op->expr_->dtype().base() == DataType::Float16) {
+        switch (op->destType_.base()) {
+        case DataType::Int32:
+            os() << "__half2int_rn(";
+            (*this)(op->expr_);
+            os() << ")";
+            break;
+        case DataType::Int64:
+            os() << "__half2ll_rn(";
+            (*this)(op->expr_);
+            os() << ")";
+            break;
+        case DataType::Float16:
+            (*this)(op->expr_);
+            break;
+        case DataType::Float32:
+            os() << "__half2float("; // Short to long, no rounding is needed
+            (*this)(op->expr_);
+            os() << ")";
+            break;
+        case DataType::Float64:
+            os() << "__half2double("; // Short to long, no rounding is needed
+            (*this)(op->expr_);
+            os() << ")";
+            break;
+        default:
+            throw InvalidProgram("Converting from float16 to " +
+                                 freetensor::toString(op->dtype()) +
+                                 " is not supported");
+        }
+    } else {
+        CodeGenC::visit(op);
+    }
+}
+
 void CodeGenCUDA::visit(const Store &op) {
     if (buffer(op->var_)->mtype() == MemType::GPUWarp) {
         auto id = mangle(op->var_);
@@ -843,50 +952,64 @@ void CodeGenCUDA::visit(const MatMul &op) {
         }
 
         makeIndent();
-        os() << "cutlass::gemm::device::Gemm<" << gen(op->a_->dtype()) << ", "
+        beginBlock();
+        makeIndent();
+        os() << "using Gemm = cutlass::gemm::device::Gemm<"
+             << genCUTLASSType(op->a_->dtype()) << ", "
              << (transA ? "cutlass::layout::ColumnMajor"
                         : "cutlass::layout::RowMajor")
-             << ", " << gen(op->b_->dtype()) << ", "
+             << ", " << genCUTLASSType(op->b_->dtype()) << ", "
             << (transB ? "cutlass::layout::ColumnMajor"
                        : "cutlass::layout::RowMajor")
-             << ", " << gen(op->c_->dtype()) << ", "
+             << ", " << genCUTLASSType(op->c_->dtype()) << ", "
             << (transC ? "cutlass::layout::ColumnMajor"
                        : "cutlass::layout::RowMajor")
-             << ", " << gen(op->c_->dtype()) // TODO: accumulator type
+             << ", "
+             << genCUTLASSType(op->c_->dtype()) // TODO: accumulator type
             << ", "
             << (canUseTensorCore(target_, op->a_->dtype(), op->b_->dtype(),
                                  op->c_->dtype())
                     ? "cutlass::arch::OpClassTensorOp"
"cutlass::arch::OpClassTensorOp" : "cutlass::arch::OpClassSimt") - << ", FT_CUTLASS_ARCH> gemm;" << std::endl; + << ", FT_CUTLASS_ARCH>;" << std::endl; makeIndent(); - os() << "checkCutlassError(gemm({{"; + os() << "Gemm gemm;" << std::endl; + // In order for clearer error message, please keep the explicit argument + // types in the following statement. + makeIndent(); + os() << "checkCutlassError(gemm(Gemm::Arguments{{"; (*this)(m); os() << ", "; (*this)(n); os() << ", "; (*this)(k); - os() << "}, {&"; + os() << "}, Gemm::TensorRefA{(const " << genCUTLASSType(op->a_->dtype()) + << "*)&"; (*this)(a); os() << ", "; (*this)(lda); - os() << "}, {&"; + os() << "}, Gemm::TensorRefB{(const " << genCUTLASSType(op->b_->dtype()) + << "*)&"; (*this)(b); os() << ", "; (*this)(ldb); - os() << "}, {&"; + os() << "}, Gemm::TensorRefC{(const " << genCUTLASSType(op->c_->dtype()) + << "*)&"; (*this)(c); os() << ", "; (*this)(ldc); - os() << "}, {&"; + os() << "}, Gemm::TensorRefD{(" << genCUTLASSType(op->c_->dtype()) + << "*)&"; (*this)(c); os() << ", "; (*this)(ldc); - os() << "}, {"; + os() << "}, Gemm::EpilogueOutputOp::Params{(" + << genCUTLASSType(op->c_->dtype()) << ")("; (*this)(op->alpha_); - os() << ", "; + os() << "), (" << genCUTLASSType(op->c_->dtype()) << ")("; (*this)(op->beta_); - os() << "}}, nullptr, __stream));" << std::endl; + os() << ")}}, nullptr, __stream));" << std::endl; + endBlock(); break; } @@ -988,7 +1111,7 @@ extern "C" { for (size_t i = 0, iEnd = shape.size(); i < iEnd; i++) { os << "__ByValArray<"; } - os << CodeGenCUDA::gen(tensor->dtype()); + os << visitor.gen(tensor->dtype()); for (auto it = shape.rbegin(); it != shape.rend(); it++) { ASSERT((*it)->nodeType() == ASTNodeType::IntConst); os << ", " << (*it).as()->val_ << ">"; diff --git a/src/codegen/detail/code_gen_c.h b/src/codegen/detail/code_gen_c.h index 0f697e73f..f4eaee234 100644 --- a/src/codegen/detail/code_gen_c.h +++ b/src/codegen/detail/code_gen_c.h @@ -19,10 +19,15 @@ namespace freetensor { template std::function CodeGenC::genMdPtrType(const VarDef &def, bool isConst) { - // NOTE: `[=]` implicitly capturing `this` is deprecated in C++20, but if we - // use `[=, this]`, clang will raise a warning because it will think `this` - // is duplicated. + // NOTE: `[=]` implicitly capturing `this` is deprecated in C++20. Using + // `[=]` will trigger a warning in GCC (because of deprecation), but using + // `[=, this]` will trigger a warning in Clang<17 (because it will think + // `this` is duplicated). 
+#if defined(__clang__) && __clang_major__ < 17
     return [=](std::ostream &os) -> std::ostream & {
+#else
+    return [=, this](std::ostream &os) -> std::ostream & {
+#endif
         auto &&buf = def->buffer_;
 
         if (buf->tensor()->shape().empty()) {
@@ -755,7 +760,8 @@ template void CodeGenC::visit(const Eval &op) {
     this->os() << ";" << std::endl;
 }
 
-template std::string CodeGenC::gen(DataType dtype) {
+template 
+std::string CodeGenC::gen(const DataType &dtype) {
     switch (dtype.base()) {
     case DataType::Float64:
         return "double";
@@ -768,7 +774,8 @@ template std::string CodeGenC::gen(DataType dtype) {
     case DataType::Bool:
         return "bool";
     default:
-        ASSERT(false);
+        throw InvalidProgram(toString(dtype) +
+                             " is not supported by this codegen backend");
     }
 }
diff --git a/src/type/data_type.cc b/src/type/data_type.cc
index b6e1884c3..c9c931a8a 100644
--- a/src/type/data_type.cc
+++ b/src/type/data_type.cc
@@ -11,6 +11,8 @@ size_t sizeOf(BaseDataType dtype) {
     case BaseDataType::Float32:
     case BaseDataType::Int32:
         return 4;
+    case BaseDataType::Float16:
+        return 2;
     case BaseDataType::Bool:
         return 1;
     case BaseDataType::Custom:
@@ -38,6 +40,7 @@ bool isFloat(BaseDataType dtype) {
     case BaseDataType::Never:
     case BaseDataType::Float64:
     case BaseDataType::Float32:
+    case BaseDataType::Float16:
         return true;
     default:
         return false;
diff --git a/test/40.codegen/gpu/test_gpu.py b/test/40.codegen/gpu/test_gpu.py
index 2e8ed4a50..e95201d4f 100644
--- a/test/40.codegen/gpu/test_gpu.py
+++ b/test/40.codegen/gpu/test_gpu.py
@@ -44,6 +44,40 @@ def test(x, y):
     assert np.array_equal(y_np, y_std)
 
 
+def test_float16_compute():
+    # Not testing float16 I/O here
+
+    @ft.transform
+    def test(x, y):
+        x: ft.Var[(4, 4), "float32", "input", "gpu/global"]
+        y: ft.Var[(4,), "float32", "output", "gpu/global"]
+        #! label: L1
+        for i in range(4):
+            x16 = ft.empty((4,), "float16", "gpu/local")
+            y16 = ft.empty((), "float16", "gpu/local")
+            for j in range(4):
+                x16[j] = ft.cast(x[i, j], "float16")
+            y16[...] = 0
+            for j in range(4):
+                y16[...] += x16[j]
+            y[i] = ft.cast(y16[...], "float32")
+
+    with device:
+        s = ft.Schedule(test)
+        s.parallelize("L1", "threadIdx.x")
+        func = ft.lower(s.func(), verbose=1)
+        code = ft.codegen(func, verbose=True)
+    x_np = np.random.uniform(size=(4, 4)).astype("float32")
+    y_np = np.zeros((4,), dtype="float32")
+    x_arr = ft.array(x_np)
+    y_arr = ft.array(y_np)
+    ft.build_binary(code)(x=x_arr, y=y_arr)
+    y_np = y_arr.numpy()
+
+    y_std = np.sum(x_np.astype("float16"), axis=-1).astype("float32")
+    assert np.all(np.isclose(y_np, y_std, atol=1e-2))
+
+
 def test_error_wrong_target():
 
     @ft.transform
diff --git a/test/40.codegen/gpu/test_gpu_cublas.py b/test/40.codegen/gpu/test_gpu_cublas.py
index da98857b1..41d72c876 100644
--- a/test/40.codegen/gpu/test_gpu_cublas.py
+++ b/test/40.codegen/gpu/test_gpu_cublas.py
@@ -10,7 +10,7 @@
 target = device.target()
 
 
-def test_basic():
+def test_float32():
 
     @ft.transform
     def test(a, b, c):
@@ -38,3 +38,60 @@ def test(a, b, c):
     c_result = c_arr.numpy()
 
     assert np.all(np.isclose(c_result, c_np + a_np @ b_np))
+
+
+def test_float16():
+    # Not testing float16 I/O here
+
+    @ft.transform
+    def test(a, b, c):
+        a: ft.Var[(48, 64), "float32", "input", "gpu/global"]
+        b: ft.Var[(64, 72), "float32", "input", "gpu/global"]
+        c: ft.Var[(48, 72), "float32", "inout", "gpu/global"]
+        a16 = ft.empty((48, 64), "float16", "gpu/global")
+        b16 = ft.empty((64, 72), "float16", "gpu/global")
+        c16 = ft.empty((48, 72), "float16", "gpu/global")
+        #! label: La_in
+        for i in range(48):
+            for j in range(64):
+                a16[i, j] = ft.cast(a[i, j], "float16")
+        #! label: Lb_in
+        for i in range(64):
+            for j in range(72):
+                b16[i, j] = ft.cast(b[i, j], "float16")
+        #! label: Lc_in
+        for i in range(48):
+            for j in range(72):
+                c16[i, j] = ft.cast(c[i, j], "float16")
+        #! label: L1
+        for i in range(48):
+            for j in range(72):
+                for k in range(64):
+                    c16[i, j] += a16[i, k] * b16[k, j]
+        #! label: Lc_out
+        for i in range(48):
+            for j in range(72):
+                c[i, j] = ft.cast(c16[i, j], "float32")
+
+    s = ft.Schedule(test)
+    s.parallelize("La_in", "blockIdx.x")
+    s.parallelize("Lb_in", "blockIdx.x")
+    s.parallelize("Lc_in", "blockIdx.x")
+    s.parallelize("Lc_out", "blockIdx.x")
+    s.as_matmul("L1", ft.AsMatMulMode.KeepMemLayout, target, "cublas")
+    func = ft.lower(s.func(), target, verbose=1)
+    code = ft.codegen(func, target, verbose=True)
+    assert "cublas" in code.code
+    assert "CUDA_R_16F" in code.code
+    a_np = np.random.uniform(size=(48, 64)).astype("float32")
+    b_np = np.random.uniform(size=(64, 72)).astype("float32")
+    c_np = np.random.uniform(size=(48, 72)).astype("float32")
+    a_arr = ft.Array(a_np)
+    b_arr = ft.Array(b_np)
+    c_arr = ft.Array(c_np.copy())
+    ft.build_binary(code, device)(a=a_arr, b=b_arr, c=c_arr)
+    c_result = c_arr.numpy()
+
+    c_std = (c_np.astype("float16") +
+             a_np.astype("float16") @ b_np.astype("float16")).astype("float32")
+    assert np.all(np.isclose(c_result, c_std, atol=1e-2, rtol=1e-2))
diff --git a/test/40.codegen/gpu/test_gpu_cutlass.py b/test/40.codegen/gpu/test_gpu_cutlass.py
index 7215f330f..539b2fa08 100644
--- a/test/40.codegen/gpu/test_gpu_cutlass.py
+++ b/test/40.codegen/gpu/test_gpu_cutlass.py
@@ -10,7 +10,7 @@
 target = device.target()
 
 
-def test_fp64():
+def test_float64():
 
     @ft.transform
     def test(a, b, c):
@@ -40,7 +40,7 @@ def test(a, b, c):
     assert np.all(np.isclose(c_result, c_np + a_np @ b_np))
 
 
-def test_fp32():
+def test_float32():
 
     @ft.transform
     def test(a, b, c):
@@ -68,3 +68,59 @@ def test(a, b, c):
     c_result = c_arr.numpy()
 
     assert np.all(np.isclose(c_result, c_np + a_np @ b_np))
+
+
+def test_float16():
+    # Not testing float16 I/O here
+
+    @ft.transform
+    def test(a, b, c):
+        a: ft.Var[(48, 64), "float32", "input", "gpu/global"]
+        b: ft.Var[(64, 72), "float32", "input", "gpu/global"]
+        c: ft.Var[(48, 72), "float32", "inout", "gpu/global"]
+        a16 = ft.empty((48, 64), "float16", "gpu/global")
+        b16 = ft.empty((64, 72), "float16", "gpu/global")
+        c16 = ft.empty((48, 72), "float16", "gpu/global")
+        #! label: La_in
+        for i in range(48):
+            for j in range(64):
+                a16[i, j] = ft.cast(a[i, j], "float16")
+        #! label: Lb_in
+        for i in range(64):
+            for j in range(72):
+                b16[i, j] = ft.cast(b[i, j], "float16")
+        #! label: Lc_in
+        for i in range(48):
+            for j in range(72):
+                c16[i, j] = ft.cast(c[i, j], "float16")
+        #! label: L1
+        for i in range(48):
+            for j in range(72):
+                for k in range(64):
+                    c16[i, j] += a16[i, k] * b16[k, j]
+        #! label: Lc_out
+        for i in range(48):
+            for j in range(72):
+                c[i, j] = ft.cast(c16[i, j], "float32")
+
+    s = ft.Schedule(test)
+    s.parallelize("La_in", "blockIdx.x")
+    s.parallelize("Lb_in", "blockIdx.x")
+    s.parallelize("Lc_in", "blockIdx.x")
+    s.parallelize("Lc_out", "blockIdx.x")
+    s.as_matmul("L1", ft.AsMatMulMode.KeepMemLayout, target, "cutlass")
+    func = ft.lower(s.func(), target, verbose=1)
+    code = ft.codegen(func, target, verbose=True)
+    assert "cutlass" in code.code
+    a_np = np.random.uniform(size=(48, 64)).astype("float32")
+    b_np = np.random.uniform(size=(64, 72)).astype("float32")
+    c_np = np.random.uniform(size=(48, 72)).astype("float32")
+    a_arr = ft.Array(a_np)
+    b_arr = ft.Array(b_np)
+    c_arr = ft.Array(c_np.copy())
+    ft.build_binary(code, device)(a=a_arr, b=b_arr, c=c_arr)
+    c_result = c_arr.numpy()
+
+    c_std = (c_np.astype("float16") +
+             a_np.astype("float16") @ b_np.astype("float16")).astype("float32")
+    assert np.all(np.isclose(c_result, c_std, atol=1e-2, rtol=1e-2))
diff --git a/test/50.frontend/test_transformer_basic.py b/test/50.frontend/test_transformer_basic.py
index ae62ad3ba..47ab450f8 100644
--- a/test/50.frontend/test_transformer_basic.py
+++ b/test/50.frontend/test_transformer_basic.py
@@ -626,3 +626,17 @@ def expect(x: ft.Var[(), "int32", "inout"]):
         x[...] = y3[...]
 
     assert expect.body.match(test.body)
+
+
+def test_metadata_redundant_spaces():
+
+    @ft.transform
+    def f(x):
+        x: ft.Var[(4, 4), "float32", "output"]
+        #!label: S1
+        x[2, 3] = 2.0
+        #! label: S2
+        x[1, 0] = 3.0
+
+    assert len(ft.find_all_stmt(f, "S1")) == 1
+    assert len(ft.find_all_stmt(f, "S2")) == 1
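
Appendix A (not part of the patch): a minimal standalone sketch of the cuda_fp16 conversion intrinsics that the new CodeGenCUDA::visit(const Cast &) above emits, such as __float2half_rn and __half2float, useful for checking the rounding behaviour in isolation. The file name, kernel name, test values, and launch configuration are ours, not FreeTensor's; compile with something like `nvcc -arch=sm_70 cast_sketch.cu`.

// cast_sketch.cu -- hypothetical example, not part of this patch.
#include <cuda_fp16.h>
#include <cstdio>

__global__ void castRoundTrip(const float *in, float *out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        __half h = __float2half_rn(in[i]); // intrinsic emitted for float32 -> float16 casts
        out[i] = __half2float(h);          // intrinsic emitted for float16 -> float32 casts
    }
}

int main() {
    const int n = 4;
    float hostIn[n] = {0.1f, 1.5f, -2.25f, 65504.0f}, hostOut[n];
    float *devIn, *devOut;
    cudaMalloc(&devIn, n * sizeof(float));
    cudaMalloc(&devOut, n * sizeof(float));
    cudaMemcpy(devIn, hostIn, n * sizeof(float), cudaMemcpyHostToDevice);
    castRoundTrip<<<1, n>>>(devIn, devOut, n);
    cudaMemcpy(hostOut, devOut, n * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; i++)
        printf("%g -> %g\n", hostIn[i], hostOut[i]); // 0.1 loses precision; 65504 is the fp16 max
    cudaFree(devIn);
    cudaFree(devOut);
    return 0;
}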
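
Appendix B (not part of the patch): a sketch of the kind of cutlass::half_t GEMM instantiation that the generated code from CodeGenCUDA::visit(const MatMul &) produces, with the same type list and Arguments layout as the emitted checkCutlassError(gemm(...)) call. It assumes CUTLASS 2.x, an sm_70+ device in place of FT_CUTLASS_ARCH, a float accumulator, and CUTLASS's default tile shapes and alignment; this is an illustration of the template usage, not the exact emitted code.

// cutlass_fp16_sketch.cu -- hypothetical example, not part of this patch.
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    using Gemm = cutlass::gemm::device::Gemm<
        cutlass::half_t, cutlass::layout::RowMajor, // A
        cutlass::half_t, cutlass::layout::RowMajor, // B
        cutlass::half_t, cutlass::layout::RowMajor, // C and D
        float,                                      // accumulator
        cutlass::arch::OpClassTensorOp,             // tensor cores (cf. canUseTensorCore)
        cutlass::arch::Sm70>;                       // stand-in for FT_CUTLASS_ARCH

    int m = 128, n = 128, k = 128;
    cutlass::half_t *a, *b, *c;
    cudaMalloc(&a, sizeof(cutlass::half_t) * m * k);
    cudaMalloc(&b, sizeof(cutlass::half_t) * k * n);
    cudaMalloc(&c, sizeof(cutlass::half_t) * m * n);
    // ... fill a, b, and c on the device here ...

    // Same argument structure as the generated call: problem size,
    // TensorRef A/B/C, TensorRef D (output), then {alpha, beta}.
    Gemm gemm;
    cutlass::Status st = gemm(
        Gemm::Arguments{{m, n, k}, {a, k}, {b, n}, {c, n}, {c, n}, {1.0f, 1.0f}});
    printf("status: %d\n", int(st)); // cutlass::Status::kSuccess == 0

    cudaFree(a);
    cudaFree(b);
    cudaFree(c);
    return 0;
}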