-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ttnn ops backend metal wrapper lib (#1230)
TTNNOpModelLib initial version. To be used for op model interface (constraints, l1, perf). builds with -DTTMLIR_ENABLE_OPMODEL=ON.
- Loading branch information
Showing
10 changed files
with
334 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,8 +47,9 @@ jobs: | |
fail-fast: false | ||
matrix: | ||
build: [ | ||
{runs-on: ubuntu-latest, enable_perf: OFF, name: "run", ttrt_flags: ""}, | ||
{runs-on: ubuntu-latest, enable_perf: ON, name: "perf", ttrt_flags: ""}, | ||
{runs-on: ubuntu-latest, enable_perf: OFF, enable_op_model: OFF, name: "run", ttrt_flags: ""}, | ||
{runs-on: ubuntu-latest, enable_perf: ON, enable_op_model: OFF, name: "perf", ttrt_flags: ""}, | ||
{runs-on: ubuntu-latest, enable_perf: OFF, enable_op_model: ON, name: "op_model" , ttrt_flags: ""} | ||
] | ||
|
||
name: Build tt-mlir | ||
|
@@ -78,7 +79,7 @@ jobs: | |
uses: hendrikmuhs/[email protected] | ||
with: | ||
create-symlink: true | ||
key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-${{ env.SDK_VERSION }} | ||
key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }}-${{ env.SDK_VERSION }} | ||
|
||
# Build project | ||
|
||
|
@@ -97,6 +98,7 @@ jobs: | |
-DTTMLIR_ENABLE_RUNTIME_TESTS=ON \ | ||
-DTT_RUNTIME_ENABLE_PERF_TRACE=${{ matrix.build.enable_perf }} \ | ||
-DTTMLIR_ENABLE_STABLEHLO=ON \ | ||
-DTTMLIR_ENABLE_OP_MODEL=${{ matrix.build.enable_op_model }} \ | ||
-S ${{ steps.strings.outputs.work-dir }} | ||
- name: Build | ||
|
@@ -147,7 +149,7 @@ jobs: | |
- name: Upload Test Report | ||
uses: actions/upload-artifact@v4 | ||
with: | ||
name: test-reports-${{ matrix.build.runs-on }}-perf-${{ matrix.build.enable_perf }} | ||
name: test-reports-${{ matrix.build.runs-on }}-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }} | ||
path: build/test/report.xml | ||
|
||
- name: Show Test Report | ||
|
@@ -480,7 +482,7 @@ jobs: | |
uses: hendrikmuhs/[email protected] | ||
with: | ||
create-symlink: true | ||
key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-${{ env.SDK_VERSION }} | ||
key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }}-${{ env.SDK_VERSION }} | ||
|
||
- name: Configure CMake | ||
shell: bash | ||
|
@@ -496,6 +498,7 @@ jobs: | |
-DTTMLIR_ENABLE_RUNTIME_TESTS=OFF \ | ||
-DTT_RUNTIME_ENABLE_PERF_TRACE=${{ matrix.build.enable_perf }} \ | ||
-DTTMLIR_ENABLE_STABLEHLO=OFF \ | ||
-DTTMLIR_ENABLE_OP_MODEL=${{ matrix.build.enable_op_model }} \ | ||
-S ${{ steps.strings.outputs.work-dir }} | ||
- name: Build tt-explorer | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_OPMODEL_TTNN_TTNNOPMODEL_H
#define TTMLIR_OPMODEL_TTNN_TTNNOPMODEL_H

#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

#include <cstddef> // BUGFIX: size_t was used without including its header
#include <tuple>

namespace mlir::tt::op_model::ttnn {

/// Op-model query interface for the TTNN Relu op.
///
/// Used by the op model interface (constraints, L1 usage, perf) to ask,
/// per input/output layout combination, whether the op is legal and how
/// much L1 memory it would consume.
struct ReluOpInterface {
  /// Returns true when Relu can execute with the given input/output layouts.
  static bool isLegal(const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout,
                      const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout);

  /// Returns the op's L1 usage as a tuple of three sizes.
  /// NOTE(review): the meaning of the three tuple elements (e.g. circular
  /// buffer size, peak L1, output buffer size) is not stated anywhere in
  /// this header -- confirm against the implementation before relying on
  /// the ordering.
  static std::tuple<std::size_t, std::size_t, std::size_t>
  getOpL1Usage(const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout,
               const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout);
};

} // namespace mlir::tt::op_model::ttnn
#endif // TTMLIR_OPMODEL_TTNN_TTNNOPMODEL_H
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
add_subdirectory(TTNN) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# TTNNOpModelLib: wraps tt-metal/ttnn op-model queries (constraints, L1
# usage, perf) for the compiler's op model interface. When the feature flag
# is off, stub implementations are compiled instead.
set(LIB_NAME TTNNOpModelLib)

# tt-metal headers require C++20.
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(SOURCES
  TTNNOpModelLib.cpp
)
add_library(${LIB_NAME} STATIC ${SOURCES})

# BUGFIX: the status message printed TTMLIR_ENABLE_OP_MODEL while the guard
# below tested TTMLIR_ENABLE_OPMODEL, so configuring with the printed
# spelling never enabled the library. Unified on TTMLIR_ENABLE_OPMODEL,
# which matches the -DTTMLIR_ENABLE_OPMODEL=ON documented usage and the
# compile definition consumed by TTNNOpModelLib.cpp.
message(STATUS "TTMLIR_ENABLE_OPMODEL[${TTMLIR_ENABLE_OPMODEL}]")
if (TTMLIR_ENABLE_OPMODEL)
  # Link to tt-metal libs and include directories.
  target_include_directories(${LIB_NAME} PUBLIC "$<BUILD_INTERFACE:${TTMETAL_INCLUDE_DIRS}>")
  target_link_libraries(${LIB_NAME} PUBLIC TTNN_LIBRARY TTMETAL_LIBRARY)
  target_compile_definitions(${LIB_NAME} PUBLIC TTMLIR_ENABLE_OPMODEL)
else()
  # Link stubs implementation when op model library is disabled.
  message(WARNING "TTNNOpModelLib is disabled. The optimizer will not achieve optimal performance.")
endif()

# Specify the include directories for the library.
target_include_directories(${LIB_NAME}
  PUBLIC
  ${CMAKE_CURRENT_SOURCE_DIR}/
  ${PROJECT_SOURCE_DIR}/include/ttmlir/OpModel/TTNN/)

# Add TTNNOpModelLib to the export set.
install(TARGETS ${LIB_NAME}
  EXPORT TTNNOpModelLibTargets
  LIBRARY DESTINATION lib
  ARCHIVE DESTINATION lib
  RUNTIME DESTINATION bin
  INCLUDES DESTINATION include)

# Export the targets.
export(EXPORT TTNNOpModelLibTargets
  FILE "${CMAKE_CURRENT_BINARY_DIR}/TTNNOpModelLibTargets.cmake"
  NAMESPACE TTNN::)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "TTNNOpModel.h" | ||
|
||
#ifdef TTMLIR_ENABLE_OPMODEL | ||
#include "TTNNOpModelLib_Impl.h" | ||
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" | ||
|
||
#include <llvm/Support/Casting.h> | ||
#include <mlir/IR/AttrTypeSubElements.h> | ||
|
||
#include <cstddef> | ||
#include <stdexcept> | ||
#endif // TTMLIR_ENABLE_OPMODEL | ||
|
||
namespace mlir::tt::op_model::ttnn { | ||
|
||
#ifdef TTMLIR_ENABLE_OPMODEL | ||
// alias to a common tt_metal types | ||
using DataType = ::tt::tt_metal::DataType; | ||
using Layout = ::tt::tt_metal::Layout; | ||
using CoreRange = ::tt::tt_metal::CoreRange; | ||
using CoreRangeSet = ::tt::tt_metal::CoreRangeSet; | ||
using CoreCoord = ::tt::tt_metal::CoreCoord; | ||
using ShardSpec = ::tt::tt_metal::ShardSpec; | ||
using ShardOrientation = ::tt::tt_metal::ShardOrientation; | ||
using TensorMemoryLayout = ::tt::tt_metal::TensorMemoryLayout; | ||
using MemoryConfig = ::tt::tt_metal::MemoryConfig; | ||
|
||
namespace detail { | ||
|
||
// Translate an MLIR memref element type into the matching tt-metal DataType.
// Throws std::runtime_error for element types with no tt-metal counterpart.
DataType getDataType(const mlir::MemRefType &memref) {
  switch (elementTypeToDataType(memref.getElementType())) {
  case tt::DataType::Float32:
    return DataType::FLOAT32;
  case tt::DataType::BFloat16:
    return DataType::BFLOAT16;
  case tt::DataType::BFP_BFloat8:
    return DataType::BFLOAT8_B;
  case tt::DataType::BFP_BFloat4:
    return DataType::BFLOAT4_B;
  case tt::DataType::UInt32:
    return DataType::UINT32;
  case tt::DataType::UInt16:
    return DataType::UINT16;
  case tt::DataType::UInt8:
    return DataType::UINT8;
  default:
    // No mapping exists for the remaining enum values.
    throw std::runtime_error("Invalid element type");
  }
}
|
||
// Copy the memref's dimensions into tt-metal's small-vector shape type.
::ttnn::SimpleShape getTensorShape(const mlir::MemRefType &memref) {
  const auto dims = memref.getShape();
  ::tt::tt_metal::SmallVector<uint32_t> shapeVec(dims.begin(), dims.end());
  return ::ttnn::SimpleShape(shapeVec);
}
|
||
const std::array<uint32_t, 2> | ||
getShardShape(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { | ||
const auto layoutShardTile = layout.getShardShape(); | ||
|
||
if (layoutShardTile.size() != 2) { | ||
llvm::errs() << "ERROR: layout_shard_tile.size() != 2\n"; | ||
return {0, 0}; | ||
} | ||
|
||
std::array<uint32_t, 2> shardShape; | ||
shardShape[0] = layoutShardTile[0]; | ||
shardShape[1] = layoutShardTile[1]; | ||
return shardShape; | ||
} | ||
|
||
// Tiled layouts map to TILE; everything else is treated as row-major.
Layout getTensorLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout) {
  if (layout.isTiled()) {
    return Layout::TILE;
  }
  return Layout::ROW_MAJOR;
}
|
||
// Build a CoreRangeSet covering the layout's core grid.
// TODO(mbezulj): handle more complex grid shapes
// assuming grid shape is one rect starting at (0,0)
CoreRangeSet getCoreRangeSet(const mlir::tt::ttnn::TTNNLayoutAttr &layout) {
  const auto gridShape = layout.getGrid().getShape();

  // Only 2D grids are supported; report the failure and return an empty set.
  // NOTE(review): the message below describes the passing condition ("== 2")
  // even though it is printed on failure -- kept byte-identical here, but it
  // should probably read "!= 2".
  if (gridShape.size() != 2) {
    llvm::errs() << "ERROR: layout_grid.getShape().size() == 2\n";
    return {};
  }

  // NOTE(review): this range runs from (0, gridShape[0]) to
  // (0, gridShape[1]), which does not start at (0,0) as the comment above
  // claims and does not span the grid rectangle -- looks like it should be
  // CoreRange(CoreCoord(0, 0), CoreCoord(gridShape[1]-1, gridShape[0]-1)).
  // Preserved as-is pending confirmation of CoreCoord (x, y) ordering.
  return CoreRangeSet(CoreRange(CoreCoord(0, gridShape[0]),
                                CoreCoord(0, gridShape[1])));
}
|
||
std::optional<ShardSpec> | ||
layout_get_shard_spec(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { | ||
// tt_ShardOrientation is not part of ttnn::TTNNLayoutAttr; | ||
// defaulting to ROW_MAJOR. TODO: figure out if we need to expose this | ||
return isShardedMemoryLayout(layout.getMemLayout()) | ||
? std::make_optional(ShardSpec(getCoreRangeSet(layout), | ||
getShardShape(layout), | ||
ShardOrientation::ROW_MAJOR, false)) | ||
: std::nullopt; | ||
} | ||
|
||
// Map the memref's MLIR memory space onto a tt-metal buffer type.
// Throws std::runtime_error for memory spaces without a mapping.
::tt::tt_metal::BufferType getBufferType(const mlir::MemRefType &memref) {
  const auto memorySpace =
      mlir::cast<tt::MemorySpaceAttr>(memref.getMemorySpace()).getValue();

  if (memorySpace == tt::MemorySpace::DeviceDRAM) {
    return ::tt::tt_metal::BufferType::DRAM;
  }
  if (memorySpace == tt::MemorySpace::DeviceL1) {
    return ::tt::tt_metal::BufferType::L1;
  }
  // TODO(mbezulj): handle other memory spaces
  throw std::runtime_error("Unsupported memory space");
}
|
||
// Map the MLIR ttnn tensor memory layout enum onto its tt-metal counterpart.
// Throws std::runtime_error for enum values without a mapping.
::tt::tt_metal::TensorMemoryLayout
getTensorMemoryLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout) {
  using MlirTml = mlir::tt::ttnn::TensorMemoryLayout;
  using MetalTml = ::tt::tt_metal::TensorMemoryLayout;

  switch (layout.getMemLayout()) {
  case MlirTml::Interleaved:
    return MetalTml::INTERLEAVED;
  case MlirTml::SingleBank:
    return MetalTml::SINGLE_BANK;
  case MlirTml::HeightSharded:
    return MetalTml::HEIGHT_SHARDED;
  case MlirTml::WidthSharded:
    return MetalTml::WIDTH_SHARDED;
  case MlirTml::BlockSharded:
    return MetalTml::BLOCK_SHARDED;
  default:
    throw std::runtime_error("Unsupported tensor memory layout");
  }
}
|
||
// Assemble a tt-metal MemoryConfig from the layout's tensor memory layout,
// buffer type, and (for sharded layouts) shard spec.
::tt::tt_metal::MemoryConfig
getMemoryConfig(const mlir::tt::ttnn::TTNNLayoutAttr &layout) {
  return ::tt::tt_metal::MemoryConfig(getTensorMemoryLayout(layout),
                                      getBufferType(layout.getMemref()),
                                      layout_get_shard_spec(layout));
}
|
||
} // namespace detail | ||
#endif // TTMLIR_ENABLE_OPMODEL | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ReluOp | ||
//===----------------------------------------------------------------------===// | ||
|
||
bool ReluOpInterface::isLegal( | ||
const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, | ||
const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { | ||
|
||
#ifdef TTMLIR_ENABLE_OPMODEL | ||
return true; // to wire into tt-metal with the next uplift | ||
#else | ||
return true; | ||
#endif // TTMLIR_ENABLE_OPMODEL | ||
} | ||
|
||
std::tuple<size_t, size_t, size_t> ReluOpInterface::getOpL1Usage( | ||
const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, | ||
const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { | ||
#ifdef TTMLIR_ENABLE_OPMODEL | ||
return std::make_tuple(0, 0, 0); // to wire into tt-metal with the next uplift | ||
#else | ||
return std::make_tuple(0, 0, 0); | ||
#endif // TTMLIR_ENABLE_OPMODEL | ||
} | ||
|
||
} // namespace mlir::tt::op_model::ttnn |
Oops, something went wrong.