From 3745a8845e1fe16d961436684a0311c9c7b79931 Mon Sep 17 00:00:00 2001
From: Marko Bezulj <156311081+mbezuljTT@users.noreply.github.com>
Date: Tue, 31 Dec 2024 12:21:55 +0100
Subject: [PATCH] Enabling op model interface for constraints and L1 usage. (#1554)

This PR plumbs OpModelInterface through to the underlying tt-metal op queries for op validation and L1 memory usage estimation.

`TTNNOpModelInterface.td` getOpConstraints takes the input(s) and output `TTNNLayoutAttr` and returns a tuple of three values:
1. A boolean indicating whether the op is legal for the given input/output layouts.
2. If the op is legal, a tuple of three values representing the op's L1 memory usage estimate in bytes:
   - the CB L1 peak allocation,
   - the Tensor L1 peak allocation,
   - the Output L1 buffer allocation.
3. If the op is illegal, a string describing the failure.

`TTNNOpModelInterface.cpp` implements hooks into the _wrapper library_ `TTNNOpModelLib` (where the metal API lives). For each op, the implementation takes
- the tensor shapes (`llvm::ArrayRef<>`) from its operands,
- the worker grid (used for virtual-to-physical core conversion), and
- op-specific parameters (like the softmax dimension),
and passes them, together with the `TTNNLayoutAttr` layouts, to the _wrapper library_ `TTNNOpModelLib`.

`TTNNOpModelLib` converts the MLIR structures to metal structures and calls into the underlying tt-metal op interface. The underlying tt-metal op interface `::ttnn::graph::query_op_constraints(..)` consumes a target op (e.g. `ttnn::relu`) and its arguments in the order of the targeted op's implemented ::invoke function.

Implemented `SingletonDeviceContext` to avoid constantly opening/closing the device. This class should ensure the opened device is a mockup device once mockup devices are implemented on the tt-metal side (https://github.com/tenstorrent/tt-metal/issues/14000).

Added three types of unit tests:
- TestConversion - tests conversion of MLIR types to TTNN types
- TestOpModelLib - tests the interface to the metal API
- TestOpModelInterface - tests the interface built into the TTNN ops

Due to differences between the tt-metal and LLVM project setups (compiler standard, exceptions), these are implemented as plain Google unit tests, unlike the other unit tests, which are also Google unit tests but are wrapped into LLVM (and invoked via llvm-lit). As these tests require TT hardware (until the mockup device is implemented), the Build tt-mlir op_model flavour was changed to use n300 runners.

Additionally, the op model interface is wired into the ShardSolver; mnist_sharding.mlir compiles and runs. @odjuricicTT confirmed the found solution is the one we expected.
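For illustration, a minimal sketch of how a caller can consume the returned tuple (mirroring the ShardSolver usage in this PR; the `op`, `inputLayouts`, and `outputLayout` names are placeholders chosen for the example):

```cpp
// Sketch only: `op` is any TTNN op implementing the OpModel interface;
// `inputLayouts` / `outputLayout` are TTNNLayoutAttr values chosen by the caller.
auto [legal, l1Usage, errorMsg] = op.getOpConstraints(inputLayouts, outputLayout);
if (!legal) {
  // Illegal layouts: errorMsg carries the failure description.
  llvm::errs() << "constraints query failed: " << errorMsg.value() << "\n";
} else {
  // Legal layouts: unpack the L1 usage estimate (all values in bytes).
  auto [cbPeak, tensorPeak, outputBuffer] = l1Usage.value();
  (void)cbPeak;       // CB L1 peak allocation
  (void)tensorPeak;   // Tensor L1 peak allocation
  (void)outputBuffer; // Output L1 buffer allocation
}
```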
Internal doc describing more details can be found [here](https://tenstorrent-my.sharepoint.com/:w:/p/mbezulj/ETC6JOzVU9dAhQjgIAiwGt8BdjbNdXmMw-fZTo7As1BVXw?e=kM7a3c) --- .../actions/build-tt-mlir-action/action.yaml | 124 ++++ .github/workflows/build-and-test.yml | 195 +++-- CMakeLists.txt | 4 +- .../ttmlir/Dialect/TT/Utils/CoreRangeSet.h | 74 ++ .../Dialect/TTNN/IR/TTNNOpModelInterface.td | 32 +- include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | 16 +- .../TTNN/Utils/VirtualToPhysicalAffineMap.h | 76 ++ include/ttmlir/OpModel/TTNN/TTNNOpModel.h | 73 +- .../ttmlir/Target/Utils/MLIRToFlatbuffer.h | 54 +- .../TTNN/Analysis/LegalLayoutAnalysis.cpp | 29 +- lib/Dialect/TTNN/Analysis/ShardSolver.cpp | 74 +- lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp | 135 +++- lib/OpModel/TTNN/CMakeLists.txt | 14 +- lib/OpModel/TTNN/Conversion.cpp | 155 ++++ lib/OpModel/TTNN/Conversion.hpp | 49 ++ .../{TTNNOpModelLib_Impl.h => MetalHeaders.h} | 19 +- lib/OpModel/TTNN/SingletonDeviceContext.cpp | 26 + lib/OpModel/TTNN/SingletonDeviceContext.h | 46 ++ lib/OpModel/TTNN/TTNNOpModelLib.cpp | 431 ++++++++---- test/lit.cfg.py | 18 + .../insert_memreconfig_override.mlir | 2 +- .../test_override_reshard_edges.mlir | 2 +- .../TTNN/optimizer/mnist_sharding.mlir | 10 +- test/unittests/CMakeLists.txt | 1 + test/unittests/OpModel/CMakeLists.txt | 1 + test/unittests/OpModel/TTNN/CMakeLists.txt | 3 + .../OpModel/TTNN/Conversion/CMakeLists.txt | 25 + .../TTNN/Conversion/TestConversion.cpp | 530 ++++++++++++++ .../unittests/OpModel/TTNN/Lib/CMakeLists.txt | 24 + .../OpModel/TTNN/Lib/TestOpModelLib.cpp | 665 ++++++++++++++++++ test/unittests/OpModel/TTNN/Op/CMakeLists.txt | 24 + .../OpModel/TTNN/Op/TestOpModelInterface.cpp | 206 ++++++ test/unittests/OpModel/TTNN/OpModelFixture.h | 128 ++++ test/unittests/lit.cfg.py | 19 + 34 files changed, 2924 insertions(+), 360 deletions(-) create mode 100644 .github/actions/build-tt-mlir-action/action.yaml create mode 100644 include/ttmlir/Dialect/TT/Utils/CoreRangeSet.h create mode 100644 include/ttmlir/Dialect/TTNN/Utils/VirtualToPhysicalAffineMap.h create mode 100644 lib/OpModel/TTNN/Conversion.cpp create mode 100644 lib/OpModel/TTNN/Conversion.hpp rename lib/OpModel/TTNN/{TTNNOpModelLib_Impl.h => MetalHeaders.h} (77%) create mode 100644 lib/OpModel/TTNN/SingletonDeviceContext.cpp create mode 100644 lib/OpModel/TTNN/SingletonDeviceContext.h create mode 100644 test/unittests/OpModel/CMakeLists.txt create mode 100644 test/unittests/OpModel/TTNN/CMakeLists.txt create mode 100644 test/unittests/OpModel/TTNN/Conversion/CMakeLists.txt create mode 100644 test/unittests/OpModel/TTNN/Conversion/TestConversion.cpp create mode 100644 test/unittests/OpModel/TTNN/Lib/CMakeLists.txt create mode 100644 test/unittests/OpModel/TTNN/Lib/TestOpModelLib.cpp create mode 100644 test/unittests/OpModel/TTNN/Op/CMakeLists.txt create mode 100644 test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp create mode 100644 test/unittests/OpModel/TTNN/OpModelFixture.h diff --git a/.github/actions/build-tt-mlir-action/action.yaml b/.github/actions/build-tt-mlir-action/action.yaml new file mode 100644 index 000000000..104d233b5 --- /dev/null +++ b/.github/actions/build-tt-mlir-action/action.yaml @@ -0,0 +1,124 @@ +name: "Build tt-mlir" +description: "Composite action for building, testing, and uploading artifacts for tt-mlir." 
+inputs: + enable-perf: + description: "Enable performance tracing" + required: true + enable-op-model: + description: "Enable op model interface tests" + required: true + build-name: + description: "A unique name for this build (e.g., 'run' or 'perf')" + required: true + build-output-dir: + description: "Build folder location" + required: true + install-output-dir: + description: "Install folder location" + required: true + work-dir: + description: "tt-mlir root" + required: true + test_report_path: + description: "Path to test report" + required: true + +runs: + using: "composite" + steps: + + - name: Configure CMake + shell: bash + run: | + source env/activate + cmake -G Ninja \ + -B ${{ inputs.build-output-dir }} \ + -DCMAKE_CXX_COMPILER=clang++-17 \ + -DCMAKE_C_COMPILER=clang-17 \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=${{ inputs.install-output-dir }} \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DTTMLIR_ENABLE_RUNTIME=ON \ + -DTTMLIR_ENABLE_RUNTIME_TESTS=ON \ + -DTT_RUNTIME_ENABLE_PERF_TRACE=${{ inputs.enable-perf }} \ + -DTTMLIR_ENABLE_STABLEHLO=ON \ + -DTTMLIR_ENABLE_OPMODEL=${{ inputs.enable-op-model }} \ + -S ${{ inputs.work-dir }} + + - name: Build and Install + shell: bash + run: | + source env/activate + cmake --build ${{ inputs.build-output-dir }} + cmake --install ${{ inputs.build-output-dir }} --component Test + + - name: Build ttrt + shell: bash + run: | + source env/activate + cmake --build ${{ inputs.build-output-dir }} -- ttrt + + - name: Generate and set system descriptor + shell: bash + if: inputs.enable-op-model == 'ON' + run: | + source env/activate + ttrt query --save-artifacts + + - name: Run tt-mlir tests + shell: bash + run: | + source env/activate + if [ -f "${{ inputs.work-dir }}/ttrt-artifacts/system_desc.ttsys" ]; then + export SYSTEM_DESC_PATH="${{ inputs.work-dir }}/ttrt-artifacts/system_desc.ttsys" + fi + cmake --build ${{ inputs.build-output-dir }} -- check-ttmlir + cp build/test/report.xml ${{ inputs.test_report_path }} + + - name: Run OpModelInterface Tests + shell: bash + if: inputs.enable-op-model == 'ON' + run: | + source env/activate + if [ -f "${{ inputs.work-dir }}/ttrt-artifacts/system_desc.ttsys" ]; then + export SYSTEM_DESC_PATH="${{ inputs.work-dir }}/ttrt-artifacts/system_desc.ttsys" + fi + ${{ inputs.build-output-dir }}/test/unittests/OpModel/TTNN/Conversion/TestConversion + ${{ inputs.build-output-dir }}/test/unittests/OpModel/TTNN/Lib/TestOpModelLib + ${{ inputs.build-output-dir }}/test/unittests/OpModel/TTNN/Op/TestOpModelInterface + + - name: Upload Test Report + uses: actions/upload-artifact@v4 + with: + name: test-reports-${{ inputs.runs-on }}-perf-${{ inputs.enable-perf }}-op_model-${{ inputs.enable-op-model }} + path: ${{ inputs.test_report_path }} + + + - name: Upload ttrt .whl + uses: actions/upload-artifact@v4 + with: + name: ttrt-whl-${{ inputs.build-name }} + path: build/runtime/tools/python/build/ttrt*.whl + + - name: Archive Install Directory + shell: bash + working-directory: ${{ inputs.install-output-dir }} + run: tar cvf artifact.tar . 
+ + - name: Upload Install Folder + uses: actions/upload-artifact@v4 + with: + name: install-artifacts-${{ inputs.build-name }} + path: ${{ inputs.install-output-dir }}/artifact.tar + + - name: Get Latest Tag and Version + shell: bash + run: | + latest_tag=$(git describe --tags --abbrev=0) + latest_tag=${latest_tag#v} + echo "latest_tag=$latest_tag" >> $GITHUB_ENV + commit_count=$(git rev-list ${{ env.latest_tag }}..HEAD --count) + echo "commit_count=$commit_count" >> $GITHUB_ENV + version="${{ env.latest_tag }}.${{ env.commit_count }}" + echo "version=$version" >> $GITHUB_ENV + echo $version diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 68db5d1cf..030b35c59 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -80,7 +80,7 @@ jobs: -DTTMLIR_ENABLE_RUNTIME=ON \ -DTTMLIR_ENABLE_RUNTIME_TESTS=ON \ -DTTMLIR_ENABLE_STABLEHLO=ON \ - -DTTMLIR_ENABLE_OP_MODEL=ON \ + -DTTMLIR_ENABLE_OPMODEL=ON \ -S ${{ steps.strings.outputs.work-dir }} - name: Lint @@ -99,10 +99,9 @@ jobs: build: [ {runs-on: ubuntu-latest, enable_perf: OFF, enable_op_model: OFF, name: "run", ttrt_flags: ""}, {runs-on: ubuntu-latest, enable_perf: ON, enable_op_model: OFF, name: "perf", ttrt_flags: ""}, - {runs-on: ubuntu-latest, enable_perf: OFF, enable_op_model: ON, name: "op_model" , ttrt_flags: ""} ] - name: Build tt-mlir + name: Build and test tt-mlir (compute machine) runs-on: ${{ matrix.build.runs-on }} container: @@ -142,116 +141,16 @@ jobs: create-symlink: true key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }}-${{ env.SDK_VERSION }} - # Build project - - - name: Configure CMake - shell: bash - run: | - source env/activate - cmake -G Ninja \ - -B ${{ steps.strings.outputs.build-output-dir }} \ - -DCMAKE_CXX_COMPILER=clang++-17 \ - -DCMAKE_C_COMPILER=clang-17 \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_INSTALL_PREFIX=${{ steps.strings.outputs.install-output-dir }} \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - -DTTMLIR_ENABLE_RUNTIME=ON \ - -DTTMLIR_ENABLE_RUNTIME_TESTS=ON \ - -DTT_RUNTIME_ENABLE_PERF_TRACE=${{ matrix.build.enable_perf }} \ - -DTTMLIR_ENABLE_STABLEHLO=ON \ - -DTTMLIR_ENABLE_OP_MODEL=${{ matrix.build.enable_op_model }} \ - -S ${{ steps.strings.outputs.work-dir }} - - - name: Build - shell: bash - run: | - source env/activate - cmake --build ${{ steps.strings.outputs.build-output-dir }} - cmake --install ${{ steps.strings.outputs.build-output-dir }} --component Test - - - name: Unique-ify clang-tidy fixes - shell: bash - if: failure() && steps.lint.outcome == 'failure' - run: | - source env/activate - python tools/scripts/filter-clang-tidy-fixes.py ${{ steps.strings.outputs.build-output-dir }}/clang-tidy-fixes.yaml - - - name: Clang-tidy PR Comments - uses: platisd/clang-tidy-pr-comments@a8811fa17cd6bd02c52a3791b44f9840777e396a - if: failure() && steps.lint.outcome == 'failure' + - name: Run build and test tt-mlir + uses: ./.github/actions/build-tt-mlir-action with: - # The GitHub token (or a personal access token) - github_token: ${{ secrets.GITHUB_TOKEN }} - # The path to the clang-tidy fixes generated above - clang_tidy_fixes: ${{ steps.strings.outputs.build-output-dir }}/clang-tidy-fixes.yaml - # Optionally set to true if you want the Action to request - # changes in case warnings are found - request_changes: false - # Optionally set the number of comments per review - # to avoid GitHub API timeouts for heavily loaded - # pull requests - 
suggestions_per_comment: 10 - python_path: "python3" - - - name: Run Test - shell: bash - run: | - source env/activate - cmake --build ${{ steps.strings.outputs.build-output-dir }} -- check-ttmlir - cp build/test/report.xml ${{ steps.strings.outputs.test_report_path }} - - - name: Upload Test Report - uses: actions/upload-artifact@v4 - with: - name: test-reports-${{ matrix.build.runs-on }}-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }} - path: ${{ steps.strings.outputs.test_report_path }} - - - name: Show Test Report - uses: mikepenz/action-junit-report@v4 - if: success() || failure() - with: - report_paths: ${{ steps.strings.outputs.test_report_path }} - check_name: MLIR Tests - - # Build and upload ttrt - - - name: Build ttrt - shell: bash - run: | - source env/activate - cmake --build ${{ steps.strings.outputs.build-output-dir }} -- ttrt - - - name: Upload ttrt whl - uses: actions/upload-artifact@v4 - with: - name: ttrt-whl-${{ matrix.build.name }} - path: build/runtime/tools/python/build/ttrt*.whl - - # This is needed to preserve file permissions - # https://github.com/actions/upload-artifact?tab=readme-ov-file#permission-loss - - name: 'Tar install directory' - shell: bash - working-directory: ${{ steps.strings.outputs.install-output-dir }} - run: tar cvf artifact.tar . - - - name: Upload install folder to archive - uses: actions/upload-artifact@v4 - with: - name: install-artifacts-${{ matrix.build.name }} - path: ${{ steps.strings.outputs.install-output-dir }}/artifact.tar - - - name: Get the latest tag - shell: bash - run: | - latest_tag=$(git describe --tags --abbrev=0) - latest_tag=${latest_tag#v} - echo "latest_tag=$latest_tag" >> $GITHUB_ENV - commit_count=$(git rev-list ${{ env.latest_tag }}..HEAD --count) - echo "commit_count=$commit_count" >> $GITHUB_ENV - version="${{ env.latest_tag }}.${{ env.commit_count }}" - echo "version=$version" >> $GITHUB_ENV - echo $version - + enable-perf: ${{ matrix.build.enable_perf }} + enable-op-model: ${{ matrix.build.enable_op_model }} + build-name: ${{ matrix.build.name }} + build-output-dir: ${{ steps.strings.outputs.build-output-dir }} + install-output-dir: ${{ steps.strings.outputs.install-output-dir }} + work-dir: ${{ steps.strings.outputs.work-dir }} + test_report_path: ${{ steps.strings.outputs.test_report_path }} # Run tests on TT hardware @@ -673,7 +572,7 @@ jobs: -DTTMLIR_ENABLE_RUNTIME_TESTS=OFF \ -DTT_RUNTIME_ENABLE_PERF_TRACE=${{ matrix.build.enable_perf }} \ -DTTMLIR_ENABLE_STABLEHLO=OFF \ - -DTTMLIR_ENABLE_OP_MODEL=${{ matrix.build.enable_op_model }} \ + -DTTMLIR_ENABLE_OPMODEL=${{ matrix.build.enable_op_model }} \ -S ${{ steps.strings.outputs.work-dir }} - name: Build tt-explorer @@ -688,3 +587,73 @@ jobs: source env/activate pytest tools/explorer/test/run_tests.py # collect results + + + build-ttmlir-opmodelinterface: + needs: build-image + timeout-minutes: 120 + strategy: + fail-fast: false + matrix: + build: [ + {runs-on: n300, enable_perf: OFF, enable_op_model: ON, name: "op_model" , ttrt_flags: ""} + ] + + name: Run build and test tt-mlir (TT machine) + runs-on: ${{ matrix.build.runs-on }} + + container: + image: ${{ needs.build-image.outputs.docker-image }} + options: --device /dev/tenstorrent/0 + volumes: + - /dev/hugepages:/dev/hugepages + - /dev/hugepages-1G:/dev/hugepages-1G + - /etc/udev/rules.d:/etc/udev/rules.d + - /lib/modules:/lib/modules + - /opt/tt_metal_infra/provisioning/provisioning_env:/opt/tt_metal_infra/provisioning/provisioning_env + + steps: + + - uses: 
actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set reusable strings + id: strings + shell: bash + env: + job-name: "Build tt-mlir (${{ matrix.build.runs-on }}, ${{ matrix.build.enable_perf }}, ${{ matrix.build.enable_op_model }}, ${{ matrix.build.name }})" + run: | + echo "work-dir=$(pwd)" >> "$GITHUB_OUTPUT" + echo "build-output-dir=$(pwd)/build" >> "$GITHUB_OUTPUT" + echo "install-output-dir=$(pwd)/install" >> "$GITHUB_OUTPUT" + + # Github job context unfortunately doesn't contain job_id, this is the workaround how to fetch it using GH API + echo "Expected job name: ${{ env.job-name }}" + JOB_ID=$(curl -s -H "Authorization: token ${{ secrets.GH_TOKEN }}" \ + "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/attempts/${{ github.run_attempt }}/jobs" | \ + jq -r '.jobs[] | select(.name | contains("${{ env.job-name }}")) | .id ') + echo "Current job id: $JOB_ID" + echo "job-id=$JOB_ID" >> "$GITHUB_OUTPUT" + echo "test_report_path=report_$JOB_ID.xml" >> "$GITHUB_OUTPUT" + + - name: Git safe dir + run: git config --global --add safe.directory ${{ steps.strings.outputs.work-dir }} + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2 + with: + create-symlink: true + key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }}-${{ env.SDK_VERSION }} + + # Build project + - name: Run build and test tt-mlir + uses: ./.github/actions/build-tt-mlir-action + with: + enable-perf: ${{ matrix.build.enable_perf }} + enable-op-model: ${{ matrix.build.enable_op_model }} + build-name: ${{ matrix.build.name }} + build-output-dir: ${{ steps.strings.outputs.build-output-dir }} + install-output-dir: ${{ steps.strings.outputs.install-output-dir }} + work-dir: ${{ steps.strings.outputs.work-dir }} + test_report_path: ${{ steps.strings.outputs.test_report_path }} diff --git a/CMakeLists.txt b/CMakeLists.txt index bb6bcda75..d9aa11369 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,7 +8,7 @@ endif() option(TT_RUNTIME_ENABLE_PERF_TRACE "Enable performance mode" OFF) option(TTMLIR_ENABLE_RUNTIME "Enable runtime" OFF) option(TTMLIR_ENABLE_STABLEHLO "Enable StableHLO support" OFF) -option(TTMLIR_ENABLE_OP_MODEL "Enable OpModel support" OFF) +option(TTMLIR_ENABLE_OPMODEL "Enable OpModel support" OFF) option(TTMLIR_ENABLE_SHARED_LIB "Enable Shared lib building" ON) if (NOT TTMLIR_ENABLE_RUNTIME) @@ -27,7 +27,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(TTMLIR_ENABLE_BINDINGS_PYTHON ON CACHE BOOL "Enable Python bindings") if (APPLE) - set(TTMLIR_ENABLE_OP_MODEL OFF) + set(TTMLIR_ENABLE_OPMODEL OFF) message(WARNING "TTNNOpModelLib is disabled on Apple platforms. Optimizer will not get true performance.") endif() diff --git a/include/ttmlir/Dialect/TT/Utils/CoreRangeSet.h b/include/ttmlir/Dialect/TT/Utils/CoreRangeSet.h new file mode 100644 index 000000000..f35b35c54 --- /dev/null +++ b/include/ttmlir/Dialect/TT/Utils/CoreRangeSet.h @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_TT_UTILS_CORERANGESET_H +#define TTMLIR_DIALECT_TT_UTILS_CORERANGESET_H + +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Utils.h" + +#include + +namespace mlir::tt::utils { + +using locSize2d = std::tuple, // {{locX, locY}, + std::array>; // {sizeX, sizeY}} + +/// Converts a virtual grid to a set of core ranges on a device grid. 
+/// +/// This function takes a virtual grid, and maps the virtual +/// grid coordinates to the device grid coordinates using the provided virtual +/// grid affine mapping. It then generates a set of core ranges, where each core +/// range is represented by a starting location and a size. The function merges +/// adjacent core ranges to form larger ranges when possible. +/// +/// \param virtualGrid The virtual grid attributes. +/// \returns A vector of core ranges, where each core range is represented by +/// a pair of location and size (both 2D). +inline std::vector +toCoreRangeSet(const llvm::ArrayRef virtualGridShape, + const mlir::AffineMap mapping) { + std::vector coreRangeSet; + ::ttmlir::utils::sample( + virtualGridShape, [&](ArrayRef virtualCoreCoord) { + llvm::SmallVector coreCoord = + mapping.compose(virtualCoreCoord); + assert(coreCoord.size() == PhysGridResultIdx::NumIndices && + "expected a 2D core"); + assert(coreCoord[PhysGridResultIdx::DeviceIdx] == 0 && + "expected single device"); + + if (!coreRangeSet.empty() && + ((std::get<0>(coreRangeSet.back())[1] == + coreCoord[PhysGridResultIdx::CoreCoordY]) && + (std::get<0>(coreRangeSet.back())[0] + + std::get<1>(coreRangeSet.back())[0]) == + coreCoord[PhysGridResultIdx::CoreCoordX])) { + const auto &[loc, size] = coreRangeSet.back(); + coreRangeSet.back() = {loc, {size[0] + 1, size[1]}}; + } else { + coreRangeSet.push_back( + {{static_cast(coreCoord[PhysGridResultIdx::CoreCoordX]), + static_cast(coreCoord[PhysGridResultIdx::CoreCoordY])}, + {1, 1}}); + } + if (coreRangeSet.size() > 1) { + const auto &[locPrev, sizePrev] = + coreRangeSet[coreRangeSet.size() - 2]; + const auto &[loc, size] = coreRangeSet.back(); + if ((locPrev[0] == loc[0]) && (sizePrev[0] == size[0]) && + ((locPrev[1] + sizePrev[1]) == loc[1])) { + assert(size[1] == 1); + coreRangeSet[coreRangeSet.size() - 2] = { + locPrev, {sizePrev[0], sizePrev[1] + 1}}; + coreRangeSet.pop_back(); + } + } + }); + return coreRangeSet; +} + +} // namespace mlir::tt::utils + +#endif // TTMLIR_DIALECT_TT_UTILS_CORERANGESET_H diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.td index 4851d08f4..91753189b 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.td @@ -7,6 +7,7 @@ include "mlir/IR/OpBase.td" +// TODO(odjuricic): support ops with multiple outputs def TTNN_OpModelInterface : OpInterface<"OpModel"> { let description = [{ Interface to access TTNN op model methods. @@ -19,32 +20,25 @@ def TTNN_OpModelInterface : OpInterface<"OpModel"> { }], /*retTy=*/"size_t", /*methodName=*/"getOpPerfCycles", - /*args=*/(ins "const std::vector&":$input_layouts, "const TTNNLayoutAttr&":$output_layout), + /*args=*/(ins "const std::vector&":$inputs, "const TTNNLayoutAttr&":$output), /*methodBody=*/"", /*defaultImplementation=*/"return std::numeric_limits::max();" >, InterfaceMethod< /*desc=*/[{ - Returns the op memory L1 usage estimate in bytes. The return value is a tuple of 3 values: - - The first value is CB L1 peak allocation in bytes. - - The second value is Tensor L1 peak allocation in bytes. - - The third value is Output L1 buffer allocation in bytes. + Returns a tuple of three values:** + 1. A boolean indicating if the op is legal for the given input/output layouts. + 2. If the op is legal, a tuple of three values representing the op memory L1 usage estimate in bytes. + - The first value is the CB L1 peak allocation in bytes. 
+ - The second value is the Tensor L1 peak allocation in bytes. + - The third value is the Output L1 buffer allocation in bytes. + 3. If the op is illegal, a string describing the failure. }], - /*retTy=*/"std::tuple", - /*methodName=*/"getOpL1Usage", - /*args=*/(ins "const std::vector&":$input_layouts, "const TTNNLayoutAttr&":$output_layout), + /*retTy=*/"std::tuple>, std::optional>", + /*methodName=*/"getOpConstraints", + /*args=*/(ins "const std::vector&":$inputs, "const TTNNLayoutAttr&":$output), /*methodBody=*/"", - /*defaultImplementation=*/"return std::make_tuple(0,0,0);" - >, - InterfaceMethod< - /*desc=*/[{ - Returns if input/output layouts are legal for the op. - }], - /*retTy=*/"bool", - /*methodName=*/"isOpLegal", - /*args=*/(ins "const std::vector&":$input_layouts, "const TTNNLayoutAttr&":$output_layout), - /*methodBody=*/"", - /*defaultImplementation=*/"return true;" + /*defaultImplementation=*/"return std::make_tuple(true,std::make_tuple(0,0,0), std::nullopt);" >, ]; } diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 8ef3f73ca..016618511 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -288,7 +288,7 @@ def TTNN_ReciprocalOp : TTNN_ElementwiseUnaryOp<"reciprocal"> { } def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu", - [DeclareOpInterfaceMethods] + [DeclareOpInterfaceMethods] > { let summary = "Eltwise ReLU."; let description = [{ @@ -397,7 +397,10 @@ def TTNN_LeakyReluOp : TTNN_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> { }]; } -def TTNN_AddOp : TTNN_ElementwiseBinaryOp<"add"> { +def TTNN_AddOp : TTNN_ElementwiseBinaryOp<"add", + [DeclareOpInterfaceMethods] + > { + let summary = "Eltwise add."; let description = [{ Eltwise add operation. @@ -677,7 +680,10 @@ def TTNN_EmbeddingBackwardOp : TTNN_NamedDPSOp<"embedding_bw"> { let hasVerifier = 1; } -def TTNN_SoftmaxOp : TTNN_Op<"softmax"> { +def TTNN_SoftmaxOp : TTNN_Op<"softmax", + [DeclareOpInterfaceMethods] + > { + let summary = "Softmax op."; let description = [{ Softmax operation. @@ -790,7 +796,9 @@ def TTNN_LinearOp : TTNN_NamedDPSOp<"linear"> { // ANCHOR: adding_an_op_matmul_ttnn -def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul"> { +def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul", + [DeclareOpInterfaceMethods] + > { let arguments = (ins AnyRankedTensor:$a, AnyRankedTensor:$b, AnyRankedTensor:$output); diff --git a/include/ttmlir/Dialect/TTNN/Utils/VirtualToPhysicalAffineMap.h b/include/ttmlir/Dialect/TTNN/Utils/VirtualToPhysicalAffineMap.h new file mode 100644 index 000000000..b9112af31 --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/Utils/VirtualToPhysicalAffineMap.h @@ -0,0 +1,76 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_TTNN_UTILS_VIRTUALTOPHYSICALAFFINEMAP_H +#define TTMLIR_DIALECT_TTNN_UTILS_VIRTUALTOPHYSICALAFFINEMAP_H + +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" + +#include "mlir/IR/AffineExpr.h" + +namespace mlir::tt::ttnn::utils { + +/// Creates an affine map that translates a virtual grid layout to a physical +/// grid layout for a single device based on the specified tensor memory layout. +/// +/// This function supports three types of tensor memory layouts: +/// - WidthSharded: Maps a width-sharded virtual grid (1xN) to a physical grid +/// with the specified shape. +/// - HeightSharded: Maps a height-sharded virtual grid (Mx1) to a physical grid +/// with the specified shape. 
+/// - BlockSharded: Maps a block-sharded virtual grid (MxN) directly to a +/// physical grid with the specified shape. +/// +/// \param context The MLIR context. +/// \param tensorMemoryLayout The tensor memory layout type. +/// \param physicalGridShape The shape of the physical grid, defaults to {8, 8}. +/// +/// \return An affine map that translates the virtual grid layout to the +/// physical grid layout based on the specified tensor memory layout. +AffineMap CreateSingleDeviceVirtualToPhysicalAffineMap( + MLIRContext *context, + const mlir::tt::ttnn::TensorMemoryLayout &tensorMemoryLayout, + const llvm::ArrayRef physicalGridShape = {8, 8}) { + + AffineExpr workerDeviceIdx = mlir::getAffineConstantExpr(0, context); + + switch (tensorMemoryLayout) { + case mlir::tt::ttnn::TensorMemoryLayout::WidthSharded: { + // create affine map that maps width sharded virtual grid 1xN to the + // physical grid gridShape[0] x gridShape[1] + AffineExpr virtualWidth = mlir::getAffineDimExpr(1, context); // d1 + AffineExpr workerCoreW = + mlir::getAffineConstantExpr(physicalGridShape[1], context); + AffineMap widthMap = mlir::AffineMap::get( + /*dimCount=*/2, /*symbolCount=*/0, + {workerDeviceIdx, virtualWidth.floorDiv(workerCoreW), + virtualWidth % workerCoreW}, + context); + return widthMap; + } + case mlir::tt::ttnn::TensorMemoryLayout::HeightSharded: { + // create affine map that maps height sharded virtual grid Mx1 to the + // physical grid gridShape[0] x gridShape[1] + AffineExpr virtualHeight = mlir::getAffineDimExpr(0, context); // d0 + AffineExpr workerCoreW = + mlir::getAffineConstantExpr(physicalGridShape[1], context); + AffineMap heightMap = mlir::AffineMap::get( + /*dimCount=*/2, /*symbolCount=*/0, + {workerDeviceIdx, virtualHeight.floorDiv(workerCoreW), + virtualHeight % workerCoreW}, + context); + return heightMap; + } + default: + case mlir::tt::ttnn::TensorMemoryLayout::BlockSharded: { + AffineExpr d0 = mlir::getAffineDimExpr(0, context); // d0 + AffineExpr d1 = mlir::getAffineDimExpr(1, context); // d1 + AffineMap blockMap = mlir::AffineMap::get( + /*dimCount=*/2, /*symbolCount=*/0, {workerDeviceIdx, d0, d1}, context); + return blockMap; + } + } +} +} // namespace mlir::tt::ttnn::utils +#endif // TTMLIR_DIALECT_TTNN_UTILS_VIRTUALTOPHYSICALAFFINEMAP_H diff --git a/include/ttmlir/OpModel/TTNN/TTNNOpModel.h b/include/ttmlir/OpModel/TTNN/TTNNOpModel.h index 31ac14984..3e2a546ef 100644 --- a/include/ttmlir/OpModel/TTNN/TTNNOpModel.h +++ b/include/ttmlir/OpModel/TTNN/TTNNOpModel.h @@ -7,18 +7,77 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "llvm/ADT/ArrayRef.h" + #include namespace mlir::tt::op_model::ttnn { -struct ReluOpInterface { - static bool isLegal(const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, - const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout); +//===----------------------------------------------------------------------===// +// Device +//===----------------------------------------------------------------------===// + +namespace Device { +std::tuple> +getDeviceConstraints(const mlir::tt::GridAttr &workerGrid); +}; // namespace Device + +//===----------------------------------------------------------------------===// +// ReluOp +//===----------------------------------------------------------------------===// + +namespace ReluOpInterface { +std::tuple>, + std::optional> +getOpConstraints(const llvm::ArrayRef &inputShape, + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const llvm::ArrayRef &outputShape, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout); 
+}; // namespace ReluOpInterface + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +namespace AddOpInterface { +std::tuple>, + std::optional> +getOpConstraints(const llvm::ArrayRef &inputShape_a, + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout_a, + const llvm::ArrayRef &inputShape_b, + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout_b, + const llvm::ArrayRef &outputShape, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout); +}; // namespace AddOpInterface + +//===----------------------------------------------------------------------===// +// SoftmaxOp +//===----------------------------------------------------------------------===// + +namespace SoftmaxOpInterface { +std::tuple>, + std::optional> +getOpConstraints(const llvm::ArrayRef &inputShape, + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const int dim_arg, const llvm::ArrayRef &outputShape, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout); +}; // namespace SoftmaxOpInterface + +//===----------------------------------------------------------------------===// +// MatmulOp +//===----------------------------------------------------------------------===// - static std::tuple - getOpL1Usage(const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, - const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout); -}; +namespace MatmulOpInterface { +std::tuple>, + std::optional> +getOpConstraints(const llvm::ArrayRef &inputShape_a, + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout_a, + const llvm::ArrayRef &inputShape_b, + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout_b, + const llvm::ArrayRef &outputShape, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout, + bool transpose_a, bool transpose_b); +}; // namespace MatmulOpInterface } // namespace mlir::tt::op_model::ttnn #endif // TTMLIR_OPMODEL_TTNN_TTNNOPMODEL_H diff --git a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h index ec8a2571e..3b365e748 100644 --- a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h +++ b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h @@ -6,6 +6,7 @@ #define TTMLIR_TARGET_UTILS_MLIRTOFLATBUFFER_H #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TT/Utils/CoreRangeSet.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "ttmlir/Target/Common/Target.h" #include "ttmlir/Target/Utils/FlatbufferObjectCache.h" @@ -368,47 +369,18 @@ inline std::vector<::tt::target::Dim2dRange> toFlatbuffer(FlatbufferObjectCache &cache, GridAttr tensorGrid, GridAttr deviceGrid) { std::vector<::tt::target::Dim2dRange> coreRangeSet; - SmallVector tensorGridShape(tensorGrid.getShape()); - AffineMap mapping = deviceGrid.getMapping(); - ::ttmlir::utils::sample( - tensorGridShape, [&](ArrayRef virtualCoreCoord) { - SmallVector coreCoord = mapping.compose(virtualCoreCoord); - assert(coreCoord.size() == PhysGridResultIdx::NumIndices && - "expected a 2D core"); - assert(coreCoord[PhysGridResultIdx::DeviceIdx] == 0 && - "expected single device"); - if (!coreRangeSet.empty() && - ((coreRangeSet.back().loc().y() == - coreCoord[PhysGridResultIdx::CoreCoordY]) && - (coreRangeSet.back().loc().x() + coreRangeSet.back().size().x()) == - coreCoord[PhysGridResultIdx::CoreCoordX])) { - coreRangeSet.back() = ::tt::target::Dim2dRange( - coreRangeSet.back().loc(), - ::tt::target::Dim2d(coreRangeSet.back().size().y(), - coreRangeSet.back().size().x() + 1)); - } else { - 
coreRangeSet.push_back(::tt::target::Dim2dRange( - ::tt::target::Dim2d(coreCoord[PhysGridResultIdx::CoreCoordY], - coreCoord[PhysGridResultIdx::CoreCoordX]), - ::tt::target::Dim2d(1, 1))); - } - if (coreRangeSet.size() > 1 && - (coreRangeSet[coreRangeSet.size() - 2].loc().x() == - coreRangeSet.back().loc().x()) && - (coreRangeSet[coreRangeSet.size() - 2].size().x() == - coreRangeSet.back().size().x()) && - ((coreRangeSet[coreRangeSet.size() - 2].loc().y() + - coreRangeSet[coreRangeSet.size() - 2].size().y()) == - coreRangeSet.back().loc().y())) { - assert(coreRangeSet.back().size().y() == 1); - coreRangeSet[coreRangeSet.size() - 2] = ::tt::target::Dim2dRange( - coreRangeSet[coreRangeSet.size() - 2].loc(), - ::tt::target::Dim2d( - coreRangeSet[coreRangeSet.size() - 2].size().y() + 1, - coreRangeSet[coreRangeSet.size() - 2].size().x())); - coreRangeSet.pop_back(); - } - }); + + auto mapping = (tensorGrid.getMapping().isEmpty() == true) + ? deviceGrid.getMapping() + : tensorGrid.getMapping(); + for (const auto &locsize2d : + utils::toCoreRangeSet(tensorGrid.getShape(), mapping)) { + const auto &[loc, size] = locsize2d; + coreRangeSet.push_back( + ::tt::target::Dim2dRange(::tt::target::Dim2d(loc[1], loc[0]), + ::tt::target::Dim2d(size[1], size[0]))); + } + return coreRangeSet; } diff --git a/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp b/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp index 799ef6c5c..7669ea7b6 100644 --- a/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp +++ b/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp @@ -9,6 +9,7 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h" #include "ttmlir/Dialect/TTNN/Utils/Utils.h" +#include "ttmlir/Dialect/TTNN/Utils/VirtualToPhysicalAffineMap.h" namespace mlir::tt::ttnn { @@ -248,13 +249,18 @@ void LegalLayoutAnalysis::analysisImplementation() { assert(analysisInput.maxGrid.getShape().size() == 2 && "Max device grid is expected to be 2D."); // Block Sharded + auto affineMapBs = + mlir::tt::ttnn::utils::CreateSingleDeviceVirtualToPhysicalAffineMap( + op->getContext(), TensorMemoryLayout::BlockSharded, + analysisInput.maxGrid.getShape()); for (int width = 1; width <= analysisInput.maxGrid.getShape()[0]; ++width) { for (int height = 1; height <= analysisInput.maxGrid.getShape()[1]; ++height) { shardedResults.push_back( shardedBase .withGrid(op->getContext(), tensorType, - GridAttr::get(op->getContext(), {width, height})) + GridAttr::get(op->getContext(), {width, height}, + affineMapBs)) .withMemoryLayout(op->getContext(), TensorMemoryLayout::BlockSharded)); } @@ -262,23 +268,32 @@ void LegalLayoutAnalysis::analysisImplementation() { int64_t numCores = analysisInput.maxGrid.getGridVolume(); // Height Sharded - // TODO(odjuricic): Missing affine mapping to actual grid. Need to check - // with runtime implementation on what to produce here. 
+ auto affineMapHs = + mlir::tt::ttnn::utils::CreateSingleDeviceVirtualToPhysicalAffineMap( + op->getContext(), TensorMemoryLayout::HeightSharded, + analysisInput.maxGrid.getShape()); + for (int height = 1; height <= numCores; ++height) { shardedResults.push_back( shardedBase - .withGrid(op->getContext(), tensorType, - GridAttr::get(op->getContext(), {height, 1})) + .withGrid( + op->getContext(), tensorType, + GridAttr::get(op->getContext(), {height, 1}, affineMapHs)) .withMemoryLayout(op->getContext(), TensorMemoryLayout::HeightSharded)); } // Width Sharded + auto affineMapWs = + mlir::tt::ttnn::utils::CreateSingleDeviceVirtualToPhysicalAffineMap( + op->getContext(), TensorMemoryLayout::WidthSharded, + analysisInput.maxGrid.getShape()); for (int width = 1; width <= numCores; ++width) { shardedResults.push_back( shardedBase - .withGrid(op->getContext(), tensorType, - GridAttr::get(op->getContext(), {1, width})) + .withGrid( + op->getContext(), tensorType, + GridAttr::get(op->getContext(), {1, width}, affineMapWs)) .withMemoryLayout(op->getContext(), TensorMemoryLayout::WidthSharded)); } diff --git a/lib/Dialect/TTNN/Analysis/ShardSolver.cpp b/lib/Dialect/TTNN/Analysis/ShardSolver.cpp index 6e19b5e62..5839a6e73 100644 --- a/lib/Dialect/TTNN/Analysis/ShardSolver.cpp +++ b/lib/Dialect/TTNN/Analysis/ShardSolver.cpp @@ -520,10 +520,74 @@ bool ShardSolver::checkShardCompatible( // if (OpModel backend = dyn_cast(consumerOp)) { - if (false == - backend.isOpLegal(std::vector{producerLayout}, consumerLayout)) { + + auto deviceAttr = mlir::tt::getCurrentScopeDevice(producerOp); + assert(deviceAttr); + auto workerGrid = deviceAttr.getWorkerGrid(); + + // Map consumer operands to DRAM interleave or provided producerLayout + // only one operand can be mapped to producerLayout, it's picked as first + // operand matching producerOp output shape. + + uint32_t numOperands = consumerOp->getNumOperands(); + + // Some ops have multiple operands; and some ops have output also an + // operand. TBD if there is a more robust way to get real number of inputs. + // TODO(odjuricic): cast to DPSop? + numOperands = (numOperands > 1) ? numOperands - 1 : numOperands; + std::vector inputLayouts; + + auto inputUnderCheck = + mlir::cast(producerOp->getResult(0).getType()); + bool inputUnderCheckFound = false; + + for (uint32_t i = 0; i < numOperands; i++) { + auto operand = consumerOp->getOperand(i); + auto input = mlir::cast(operand.getType()); + + if ((inputUnderCheckFound == false) && + (inputUnderCheck.getShape() == input.getShape())) { + // this is the input we are checking compatibility for + inputUnderCheckFound = true; + inputLayouts.push_back(producerLayout); + } else { + // this is the other input that we DRAM interleave + + // what if it is tilized already? 
+ auto elementType = + TileType::get(consumerOp->getContext(), input.getElementType()); + + auto layout = TTNNLayoutAttr::get( + consumerOp->getContext(), input.getShape(), elementType, + BufferType::DRAM, workerGrid, + TensorMemoryLayoutAttr::get(consumerOp->getContext(), + TensorMemoryLayout::Interleaved)); + inputLayouts.push_back(layout); + } + } + + auto [legal, l1Usage, errorMsg] = + backend.getOpConstraints(inputLayouts, consumerLayout); + + constexpr bool debug = false; + if (false == legal) { + // early exit + if (debug) { + llvm::errs() << "OpModel constraints failed: "; + llvm::errs() << producerOp->getName() << "->" << consumerOp->getName() + << " :: " << errorMsg.value() << "\n"; + producerLayout.dump(); + consumerLayout.dump(); + } return false; } + if (debug) { + llvm::errs() << "OpModel constraints valid. "; + llvm::errs() << producerOp->getName() << "->" << consumerOp->getName() + << "\n"; + producerLayout.dump(); + consumerLayout.dump(); + } } // May need to fetch other inputs for consumerOp(weights/join node). @@ -565,12 +629,6 @@ bool ShardSolver::checkShardCompatible( } } - // Shard compat assumption. Try to keep same shard layout. - // - if (producerLayout.getMemLayout() != consumerLayout.getMemLayout()) { - return false; - } - return true; } diff --git a/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp b/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp index 344a4a483..9f40df46d 100644 --- a/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp @@ -5,36 +5,141 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.cpp.inc" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "ttmlir/OpModel/TTNN/TTNNOpModel.h" +#include "mlir/IR/Operation.h" + #include +#include #include namespace mlir::tt::ttnn { +namespace detail { +std::tuple>, + std::optional> +checkDeviceWorkerGrid(mlir::Operation *op) { + + auto deviceAttr = mlir::tt::getCurrentScopeDevice(op); + assert(deviceAttr); + auto checkWorkerGrid = + op_model::ttnn::Device::getDeviceConstraints(deviceAttr.getWorkerGrid()); + + if (std::get<0>(checkWorkerGrid) == false) { + return std::make_tuple(std::get<0>(checkWorkerGrid), std::nullopt, + std::get<1>(checkWorkerGrid)); + } + + return std::make_tuple(true, std::nullopt, std::nullopt); +} +} // namespace detail + //===----------------------------------------------------------------------===// // ReluOp - TTNN Op Model Interface //===----------------------------------------------------------------------===// -size_t ReluOp::getOpPerfCycles(const std::vector &input_layouts, - const TTNNLayoutAttr &output_layout) { - // TODO(mbezulj) wire to tt-metal once we have API - return 5; +std::tuple>, + std::optional> +ReluOp::getOpConstraints(const std::vector &inputs, + const TTNNLayoutAttr &output) { + assert(inputs.size() == 1); + + const auto input_shape = + mlir::cast(getDpsInputOperand(0)->get().getType()) + .getShape(); + + const auto output_shape = + mlir::cast(getResults().front().getType()).getShape(); + + auto check = detail::checkDeviceWorkerGrid(getOperation()); + if (std::get(check) == false) { + return check; + } + + return op_model::ttnn::ReluOpInterface::getOpConstraints( + input_shape, inputs[0], output_shape, output); +} + +//===----------------------------------------------------------------------===// +// AddOp - TTNN Op Model Interface +//===----------------------------------------------------------------------===// + 
+std::tuple>, + std::optional> +AddOp::getOpConstraints(const std::vector &inputs, + const TTNNLayoutAttr &output) { + assert(inputs.size() == 2); + + const auto input_shape_a = + mlir::cast(getOperand(0).getType()).getShape(); + const auto input_shape_b = + mlir::cast(getOperand(1).getType()).getShape(); + + const auto output_shape = + mlir::cast(getResult(0).getType()).getShape(); + + auto check = detail::checkDeviceWorkerGrid(getOperation()); + if (std::get(check) == false) { + return check; + } + + return op_model::ttnn::AddOpInterface::getOpConstraints( + input_shape_a, inputs[0], input_shape_b, inputs[1], output_shape, output); } -std::tuple -ReluOp::getOpL1Usage(const std::vector &input_layouts, - const TTNNLayoutAttr &output_layout) { - assert(input_layouts.size() == 1); - return op_model::ttnn::ReluOpInterface::getOpL1Usage(input_layouts[0], - output_layout); +//===----------------------------------------------------------------------===// +// SoftmaxOp - TTNN Op Model Interface +//===----------------------------------------------------------------------===// + +std::tuple>, + std::optional> +SoftmaxOp::getOpConstraints(const std::vector &inputs, + const TTNNLayoutAttr &output) { + assert(inputs.size() == 1); + + const auto input_shape = + mlir::cast(getOperand().getType()).getShape(); + + const auto output_shape = + mlir::cast(getResult().getType()).getShape(); + + auto check = detail::checkDeviceWorkerGrid(getOperation()); + if (std::get(check) == false) { + return check; + } + + return op_model::ttnn::SoftmaxOpInterface::getOpConstraints( + input_shape, inputs[0], getDimension(), output_shape, output); } -bool ReluOp::isOpLegal(const std::vector &input_layouts, - const TTNNLayoutAttr &output_layout) { - assert(input_layouts.size() == 1); - return op_model::ttnn::ReluOpInterface::isLegal(input_layouts[0], - output_layout); +//===----------------------------------------------------------------------===// +// MatmulOp - TTNN Op Model Interface +//===----------------------------------------------------------------------===// + +std::tuple>, + std::optional> +MatmulOp::getOpConstraints(const std::vector &inputs, + const TTNNLayoutAttr &output) { + assert(inputs.size() == 2); + + const auto input_shape_a = + mlir::cast(getOperand(0).getType()).getShape(); + const auto input_shape_b = + mlir::cast(getOperand(1).getType()).getShape(); + + const auto output_shape = + mlir::cast(getResult().getType()).getShape(); + + auto check = detail::checkDeviceWorkerGrid(getOperation()); + if (std::get(check) == false) { + return check; + } + + return op_model::ttnn::MatmulOpInterface::getOpConstraints( + input_shape_a, inputs[0], input_shape_b, inputs[1], output_shape, output, + false, false); } } // namespace mlir::tt::ttnn diff --git a/lib/OpModel/TTNN/CMakeLists.txt b/lib/OpModel/TTNN/CMakeLists.txt index 094b9f1dd..1210c7ebf 100644 --- a/lib/OpModel/TTNN/CMakeLists.txt +++ b/lib/OpModel/TTNN/CMakeLists.txt @@ -5,14 +5,24 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(SOURCES TTNNOpModelLib.cpp + Conversion.cpp + SingletonDeviceContext.cpp ) add_library(${LIB_NAME} STATIC ${SOURCES}) -message(STATUS "TTMLIR_ENABLE_OP_MODEL[${TTMLIR_ENABLE_OP_MODEL}]") +message(STATUS "TTMLIR_ENABLE_OPMODEL[${TTMLIR_ENABLE_OPMODEL}]") if (TTMLIR_ENABLE_OPMODEL) + # Building op model library will invoke building tt-metal; and it requires TT_METAL_HOME and ARCH_NAME environment variables to be set. 
+ if("$ENV{TT_METAL_HOME}" STREQUAL "") + message(FATAL_ERROR "TT_METAL_HOME is not set") + endif() + if("$ENV{ARCH_NAME}" STREQUAL "") + message(FATAL_ERROR "ARCH_NAME is not set") + endif() + # Link to tt-metal libs and include directories target_include_directories(${LIB_NAME} PUBLIC "$") - target_link_libraries(${LIB_NAME} PUBLIC TTNN_LIBRARY TTMETAL_LIBRARY) + target_link_libraries(${LIB_NAME} PUBLIC TTNN_LIBRARY TTMETAL_LIBRARY DEVICE_LIBRARY) target_compile_definitions(${LIB_NAME} PUBLIC TTMLIR_ENABLE_OPMODEL) else() # link stubs implementation when op model library is disabled diff --git a/lib/OpModel/TTNN/Conversion.cpp b/lib/OpModel/TTNN/Conversion.cpp new file mode 100644 index 000000000..ce9ab2c9b --- /dev/null +++ b/lib/OpModel/TTNN/Conversion.cpp @@ -0,0 +1,155 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifdef TTMLIR_ENABLE_OPMODEL +#include "Conversion.hpp" + +#include "ttmlir/Dialect/TT/Utils/CoreRangeSet.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" + +#include "llvm/ADT/ArrayRef.h" + +namespace mlir::tt::op_model::ttnn { + +namespace conversion { + +::tt::tt_metal::DataType +getDataType(const mlir::tt::ttnn::TTNNLayoutAttr layout) { + auto dataType = layout.getDataType(); + + switch (dataType) { + case tt::DataType::Float32: + return ::tt::tt_metal::DataType::FLOAT32; + case tt::DataType::BFloat16: + return ::tt::tt_metal::DataType::BFLOAT16; + case tt::DataType::BFP_BFloat8: + return ::tt::tt_metal::DataType::BFLOAT8_B; + case tt::DataType::BFP_BFloat4: + return ::tt::tt_metal::DataType::BFLOAT4_B; + case tt::DataType::UInt32: + return ::tt::tt_metal::DataType::UINT32; + case tt::DataType::UInt16: + return ::tt::tt_metal::DataType::UINT16; + case tt::DataType::UInt8: + return ::tt::tt_metal::DataType::UINT8; + default: + throw std::runtime_error("Invalid element type"); + } +} + +::ttnn::SimpleShape getSimpleShape(const ::llvm::ArrayRef shape) { + ::tt::tt_metal::SmallVector small_vector_shape; + for (const auto &dim : shape) { + small_vector_shape.push_back(static_cast(dim)); + } + + return ::ttnn::SimpleShape(small_vector_shape); +} + +const std::array +getShardShape(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + const auto layoutShardTile = layout.getScalarShardShape(); + + if (layoutShardTile.size() != 2) { + llvm::errs() << "ERROR: layout_shard_tile.size() != 2\n"; + return {0, 0}; + } + + std::array shardShape; + shardShape[0] = layoutShardTile[0]; + shardShape[1] = layoutShardTile[1]; + return shardShape; +} + +::tt::tt_metal::Layout +getPageLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + return layout.isTiled() ? ::tt::tt_metal::Layout::TILE + : ::tt::tt_metal::Layout::ROW_MAJOR; +} + +::tt::tt_metal::CoreRangeSet +getCoreRangeSet(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + std::set<::tt::tt_metal::CoreRange> coreRangeSet; + assert(layout.getGrid().getMapping().isEmpty() == false); + for (const auto &[loc, size] : utils::toCoreRangeSet( + layout.getGrid().getShape(), layout.getGrid().getMapping())) { + coreRangeSet.insert(::tt::tt_metal::CoreRange( + CoreCoord(loc[0], loc[1]), + CoreCoord(loc[0] + size[0] - 1, loc[1] + size[1] - 1))); + } + return ::tt::tt_metal::CoreRangeSet(coreRangeSet); +} + +std::optional<::tt::tt_metal::ShardSpec> +getShardSpec(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + // tt_ShardOrientation is not part of ttnn::TTNNLayoutAttr; + // defaulting to ROW_MAJOR. 
TODO(jserbedzija): with issue #620 + return isShardedMemoryLayout(layout.getMemLayout().getValue()) + ? std::make_optional(ShardSpec(getCoreRangeSet(layout), + getShardShape(layout), + ShardOrientation::ROW_MAJOR, false)) + : std::nullopt; +} + +::tt::tt_metal::BufferType +getBufferType(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + auto bufferType = layout.getBufferType(); + + switch (bufferType) { + case mlir::tt::ttnn::BufferType::DRAM: + return ::tt::tt_metal::BufferType::DRAM; + case mlir::tt::ttnn::BufferType::L1: + return ::tt::tt_metal::BufferType::L1; + case mlir::tt::ttnn::BufferType::SystemMemory: + return ::tt::tt_metal::BufferType::SYSTEM_MEMORY; + case mlir::tt::ttnn::BufferType::L1Small: + return ::tt::tt_metal::BufferType::L1_SMALL; + case mlir::tt::ttnn::BufferType::Trace: + return ::tt::tt_metal::BufferType::TRACE; + } +} + +::tt::tt_metal::TensorMemoryLayout +getTensorMemoryLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + auto tensorMemoryLayout = layout.getMemLayout().getValue(); + + switch (tensorMemoryLayout) { + case mlir::tt::ttnn::TensorMemoryLayout::Interleaved: + return ::tt::tt_metal::TensorMemoryLayout::INTERLEAVED; + case mlir::tt::ttnn::TensorMemoryLayout::SingleBank: + return ::tt::tt_metal::TensorMemoryLayout::SINGLE_BANK; + case mlir::tt::ttnn::TensorMemoryLayout::HeightSharded: + return ::tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED; + case mlir::tt::ttnn::TensorMemoryLayout::WidthSharded: + return ::tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED; + case mlir::tt::ttnn::TensorMemoryLayout::BlockSharded: + return ::tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED; + } +} + +::tt::tt_metal::MemoryConfig +getMemoryConfig(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + + auto tensorMemoryLayout = getTensorMemoryLayout(layout); + auto bufferType = getBufferType(layout); + + auto shardSpec = getShardSpec(layout); + return ::tt::tt_metal::MemoryConfig(tensorMemoryLayout, bufferType, + shardSpec); +} + +::tt::tt_metal::TensorLayout +getTensorLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + return ::tt::tt_metal::TensorLayout( + getDataType(layout), getPageLayout(layout), getMemoryConfig(layout)); +} + +::ttnn::TensorSpec getTensorSpec(const ::llvm::ArrayRef shape, + const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + return ::ttnn::TensorSpec(getSimpleShape(shape), getTensorLayout(layout)); +} + +} // namespace conversion +} // namespace mlir::tt::op_model::ttnn +#endif // TTMLIR_ENABLE_OPMODEL diff --git a/lib/OpModel/TTNN/Conversion.hpp b/lib/OpModel/TTNN/Conversion.hpp new file mode 100644 index 000000000..137bc180d --- /dev/null +++ b/lib/OpModel/TTNN/Conversion.hpp @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifdef TTMLIR_ENABLE_OPMODEL +#include "MetalHeaders.h" + +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" + +#include "llvm/ADT/ArrayRef.h" + +namespace mlir::tt::op_model::ttnn { +namespace conversion { +::tt::tt_metal::DataType +getDataType(const mlir::tt::ttnn::TTNNLayoutAttr layout); + +::ttnn::SimpleShape getSimpleShape(const ::llvm::ArrayRef shape); + +const std::array +getShardShape(const mlir::tt::ttnn::TTNNLayoutAttr &layout); + +::tt::tt_metal::Layout +getPageLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout); + +::tt::tt_metal::CoreRangeSet +getCoreRangeSet(const mlir::tt::ttnn::TTNNLayoutAttr &layout); + +std::optional<::tt::tt_metal::ShardSpec> +getShardSpec(const mlir::tt::ttnn::TTNNLayoutAttr &layout); + +::tt::tt_metal::BufferType 
+getBufferType(const mlir::tt::ttnn::TTNNLayoutAttr &layout); + +::tt::tt_metal::TensorMemoryLayout +getTensorMemoryLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout); + +::tt::tt_metal::MemoryConfig +getMemoryConfig(const mlir::tt::ttnn::TTNNLayoutAttr &layout); + +::tt::tt_metal::TensorLayout +getTensorLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout); + +::ttnn::TensorSpec getTensorSpec(const ::llvm::ArrayRef shape, + const mlir::tt::ttnn::TTNNLayoutAttr &layout); + +} // namespace conversion +} // namespace mlir::tt::op_model::ttnn + +#endif // TTMLIR_ENABLE_OPMODEL diff --git a/lib/OpModel/TTNN/TTNNOpModelLib_Impl.h b/lib/OpModel/TTNN/MetalHeaders.h similarity index 77% rename from lib/OpModel/TTNN/TTNNOpModelLib_Impl.h rename to lib/OpModel/TTNN/MetalHeaders.h index ed39d881a..d064a92b0 100644 --- a/lib/OpModel/TTNN/TTNNOpModelLib_Impl.h +++ b/lib/OpModel/TTNN/MetalHeaders.h @@ -2,8 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 -#ifndef TTMLIR_OPMODEL_TTNN_TTNNOPMODELLIB_IMPL_H -#define TTMLIR_OPMODEL_TTNN_TTNNOPMODELLIB_IMPL_H +#ifndef TTMLIR_OPMODEL_TTNN_METALHEADERS_H +#define TTMLIR_OPMODEL_TTNN_METALHEADERS_H // This header resolves tt-metal warnings that would otherwise be treated as // errors in the MLIR build. Ensure that this is the only place where tt-metal @@ -47,14 +47,27 @@ #pragma clang diagnostic ignored "-Wc++11-narrowing" #pragma clang diagnostic ignored "-Wzero-length-array" #pragma clang diagnostic ignored "-Wdeprecated-declarations" +#pragma clang diagnostic ignored "-Werror,-Wctad-maybe-unsupported" #define FMT_HEADER_ONLY +#include "host_api.hpp" +#include "impl/buffers/buffer_constants.hpp" #include "tt_metal/common/core_coord.hpp" #include "tt_metal/impl/buffers/buffer.hpp" +#include "tt_metal/impl/device/device.hpp" +#include "ttnn/graph/graph_processor.hpp" +#include "ttnn/graph/graph_query_op_constraints.hpp" +#include "ttnn/graph/graph_trace_utils.hpp" +#include "ttnn/operations/eltwise/binary/binary.hpp" +#include "ttnn/operations/eltwise/unary/unary.hpp" +#include "ttnn/operations/matmul/matmul.hpp" +#include "ttnn/operations/normalization/softmax/softmax.hpp" +#include "ttnn/tensor/shape/small_vector.hpp" #include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/tensor_spec.hpp" #include "ttnn/tensor/types.hpp" #pragma clang diagnostic pop -#endif // TTMLIR_OPMODEL_TTNN_TTNNOPMODELLIB_IMPL_H +#endif // TTMLIR_OPMODEL_TTNN_METALHEADERS_H diff --git a/lib/OpModel/TTNN/SingletonDeviceContext.cpp b/lib/OpModel/TTNN/SingletonDeviceContext.cpp new file mode 100644 index 000000000..c9c940fa0 --- /dev/null +++ b/lib/OpModel/TTNN/SingletonDeviceContext.cpp @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifdef TTMLIR_ENABLE_OPMODEL +#include "SingletonDeviceContext.h" + +#include "MetalHeaders.h" + +namespace mlir::tt::op_model::ttnn { + +SingletonDeviceContext::SingletonDeviceContext() { + m_device = ::tt::tt_metal::CreateDevice(0); +} + +SingletonDeviceContext::~SingletonDeviceContext() { + ::tt::tt_metal::CloseDevice(m_device); +} + +SingletonDeviceContext &SingletonDeviceContext::getInstance() { + static SingletonDeviceContext instance; + return instance; +} + +} // namespace mlir::tt::op_model::ttnn +#endif // TTMLIR_ENABLE_OPMODEL diff --git a/lib/OpModel/TTNN/SingletonDeviceContext.h b/lib/OpModel/TTNN/SingletonDeviceContext.h new file mode 100644 index 000000000..9af10721b --- /dev/null +++ b/lib/OpModel/TTNN/SingletonDeviceContext.h @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: 
© 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_OPMODEL_TTNN_SINGLETONDEVICECONTEXT_H +#define TTMLIR_OPMODEL_TTNN_SINGLETONDEVICECONTEXT_H +#ifdef TTMLIR_ENABLE_OPMODEL + +#include + +namespace tt { +namespace tt_metal { +inline namespace v0 { +class Device; +} // namespace v0 +} // namespace tt_metal +} // namespace tt + +namespace mlir::tt::op_model::ttnn { + +// Singleton class to manage the device context, ensuring the device remains +// active while compiler is running multiple graph traces without real +// allocations and op dispatching. + +// TODO (mbezulj): enforce mockup/simulation device when it's enabled in +// tt-metal. + +class SingletonDeviceContext { +public: + static SingletonDeviceContext &getInstance(); + + ::tt::tt_metal::v0::Device *getDevice() { return m_device; } + +private: + SingletonDeviceContext(); + ~SingletonDeviceContext(); + + SingletonDeviceContext(const SingletonDeviceContext &) = delete; + SingletonDeviceContext &operator=(const SingletonDeviceContext &) = delete; + + ::tt::tt_metal::Device *m_device; +}; +} // namespace mlir::tt::op_model::ttnn + +#endif // TTMLIR_ENABLE_OPMODEL +#endif // TTMLIR_OPMODEL_TTNN_SINGLETONDEVICECONTEXT_H diff --git a/lib/OpModel/TTNN/TTNNOpModelLib.cpp b/lib/OpModel/TTNN/TTNNOpModelLib.cpp index 87bfc0415..908579223 100644 --- a/lib/OpModel/TTNN/TTNNOpModelLib.cpp +++ b/lib/OpModel/TTNN/TTNNOpModelLib.cpp @@ -5,178 +5,345 @@ #include "TTNNOpModel.h" #ifdef TTMLIR_ENABLE_OPMODEL -#include "TTNNOpModelLib_Impl.h" +#include "Conversion.hpp" +#include "MetalHeaders.h" +#include "SingletonDeviceContext.h" + #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" -#include -#include +#include "mlir/IR/AttrTypeSubElements.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Types.h" +#include "llvm/Support/Casting.h" #include +#include #include #endif // TTMLIR_ENABLE_OPMODEL namespace mlir::tt::op_model::ttnn { #ifdef TTMLIR_ENABLE_OPMODEL -// alias to a common tt_metal types -using DataType = ::tt::tt_metal::DataType; -using Layout = ::tt::tt_metal::Layout; -using CoreRange = ::tt::tt_metal::CoreRange; -using CoreRangeSet = ::tt::tt_metal::CoreRangeSet; -using CoreCoord = ::tt::tt_metal::CoreCoord; -using ShardSpec = ::tt::tt_metal::ShardSpec; -using ShardOrientation = ::tt::tt_metal::ShardOrientation; -using TensorMemoryLayout = ::tt::tt_metal::TensorMemoryLayout; -using MemoryConfig = ::tt::tt_metal::MemoryConfig; - -namespace detail { - -DataType getDataType(const mlir::MemRefType &memref) { - - auto dataType = elementTypeToDataType(memref.getElementType()); - - switch (dataType) { - case tt::DataType::Float32: - return DataType::FLOAT32; - case tt::DataType::BFloat16: - return DataType::BFLOAT16; - case tt::DataType::BFP_BFloat8: - return DataType::BFLOAT8_B; - case tt::DataType::BFP_BFloat4: - return DataType::BFLOAT4_B; - case tt::DataType::UInt32: - return DataType::UINT32; - case tt::DataType::UInt16: - return DataType::UINT16; - case tt::DataType::UInt8: - return DataType::UINT8; - default: - throw std::runtime_error("Invalid element type"); +namespace operation { + +/** + * @brief Retrieves operation constraints based on the provided operation name + * and callable. + * + * This function attempts to query operation constraints using the provided + * callable and arguments. 
It returns a tuple containing a boolean indicating
+ * success or failure, an optional tuple with resource usage details (if
+ * successful), and an optional error message (if failed).
+ *
+ * @param name The name of the operation to query constraints for.
+ * @param callable A callable object that performs the query.
+ * @param args Additional arguments to be forwarded to the callable.
+ * @return A tuple containing query results.
+ */
+template <class Callable>
+std::tuple<bool, std::optional<std::tuple<size_t, size_t, size_t>>,
+           std::optional<std::string>>
+getOpConstraints(const std::string_view &name, Callable &callable,
+                 auto &&...args) {
+  ::ttnn::graph::QueryResponse query;
+  try {
+    query = callable(std::forward<decltype(args)>(args)...);
+  } catch (const std::exception &e) {
+    query.status = ::ttnn::graph::ExecutionStatus::Error;
+    query.error_message = e.what();
   }
-}
-
-::ttnn::SimpleShape getTensorShape(const mlir::MemRefType &memref) {
-  ::tt::tt_metal::SmallVector small_vector_shape(
-      memref.getShape().begin(), memref.getShape().end());
-  return ::ttnn::SimpleShape(small_vector_shape);
-}
-
-const std::array
-getShardShape(const mlir::tt::ttnn::TTNNLayoutAttr &layout) {
-  const auto layoutShardTile = layout.getShardShape();
-  if (layoutShardTile.size() != 2) {
-    llvm::errs() << "ERROR: layout_shard_tile.size() != 2\n";
-    return {0, 0};
+  // check if query was successful
+  if (query.status != ::ttnn::graph::ExecutionStatus::Success) {
+    return std::make_tuple(
+        false, std::nullopt,
+        query.error_message.value_or(""));
   }
-  std::array shardShape;
-  shardShape[0] = layoutShardTile[0];
-  shardShape[1] = layoutShardTile[1];
-  return shardShape;
-}
-
-Layout getTensorLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout) {
-  return layout.isTiled() ? Layout::TILE : Layout::ROW_MAJOR;
+  return std::make_tuple(
+      true,
+      std::make_tuple(query.resource_usage.cb_peak_size_per_core,
+                      query.resource_usage.l1_buffers_peak_per_core,
+                      query.resource_usage.l1_output_buffer_per_core),
+      std::nullopt);
 }
+} // namespace operation
 
-CoreRangeSet getCoreRangeSet(const mlir::tt::ttnn::TTNNLayoutAttr &layout) {
-  // TODO(mbezulj): handle more complex grid shapes
-  // assuming grid shape is one rect starting at (0,0)
-
-  const auto layoutGrid = layout.getGrid();
-
-  const auto layoutGridShape = layoutGrid.getShape();
-  if (layoutGridShape.size() != 2) {
-    llvm::errs() << "ERROR: layout_grid.getShape().size() == 2\n";
-    return {};
-  }
-
-  return CoreRangeSet(CoreRange(CoreCoord(0, layoutGridShape[0]),
-                                CoreCoord(0, layoutGridShape[1])));
-}
-
-std::optional
-layout_get_shard_spec(const mlir::tt::ttnn::TTNNLayoutAttr &layout) {
-  // tt_ShardOrientation is not part of ttnn::TTNNLayoutAttr;
-  // defaulting to ROW_MAJOR. TODO: figure out if we need to expose this
-  return isShardedMemoryLayout(layout.getMemLayout())
-             ? std::make_optional(ShardSpec(getCoreRangeSet(layout),
-                                            getShardShape(layout),
-                                            ShardOrientation::ROW_MAJOR, false))
-             : std::nullopt;
-}
+namespace detail {
 
-::tt::tt_metal::BufferType getBufferType(const mlir::MemRefType &memref) {
-  auto memorySpace =
-      mlir::cast(memref.getMemorySpace()).getValue();
-
-  switch (memorySpace) {
-  case tt::MemorySpace::DeviceDRAM:
-    return ::tt::tt_metal::BufferType::DRAM;
-  case tt::MemorySpace::DeviceL1:
-    return ::tt::tt_metal::BufferType::L1;
-  default: // TODO(mbezulj): handle other memory spaces
-    throw std::runtime_error("Unsupported memory space");
+/**
+ * @brief Checks if the shard bounding box fits within the available grid size. 
+ * + * This function verifies whether the shard bounding box specified in the + * memory configuration fits within the range of device worker cores. If the + * memory configuration is sharded and the shard bounding box exceeds the + * available grid size, it throws a runtime error. + * + * @param computeGridSize The compute grid size. + * @param memoryConfig The memory configuration which may specify a shard. + * + * @throws std::runtime_error If the shard bounding box is larger than the + * available grid size. + */ +void checkGrid(const ::tt::tt_metal::CoreCoord &computeGridSize, + const ::tt::tt_metal::MemoryConfig &memoryConfig) { + if (memoryConfig.is_sharded()) { + ::tt::tt_metal::CoreRange shardBoundingBox = + memoryConfig.shard_spec.value().grid.bounding_box(); + ::tt::tt_metal::CoreRangeSet deviceWorkerCores{::tt::tt_metal::CoreRange{ + ::tt::tt_metal::CoreCoord{0, 0}, + ::tt::tt_metal::CoreCoord{computeGridSize.x - 1, + computeGridSize.y - 1}}}; + if (deviceWorkerCores.contains(shardBoundingBox) == false) { + throw std::runtime_error( + "Selected shard is larger than available grid " + "size. Compute Grid Size: " + + computeGridSize.str() + + ", selected bounding box: " + shardBoundingBox.str()); + } } } -::tt::tt_metal::TensorMemoryLayout -getTensorMemoryLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { - auto tensorMemoryLayout = layout.getMemLayout(); - - switch (tensorMemoryLayout) { - case mlir::tt::ttnn::TensorMemoryLayout::Interleaved: - return ::tt::tt_metal::TensorMemoryLayout::INTERLEAVED; - case mlir::tt::ttnn::TensorMemoryLayout::SingleBank: - return ::tt::tt_metal::TensorMemoryLayout::SINGLE_BANK; - case mlir::tt::ttnn::TensorMemoryLayout::HeightSharded: - return ::tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED; - case mlir::tt::ttnn::TensorMemoryLayout::WidthSharded: - return ::tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED; - case mlir::tt::ttnn::TensorMemoryLayout::BlockSharded: - return ::tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED; - default: - throw std::runtime_error("Unsupported tensor memory layout"); +/** + * @brief Checks the validity of the compute grid size. + * + * This function verifies the dimensions and properties of the provided compute + * grid size. + * + * @param computeGridSize The size of the compute grid, represented as a + * CoreCoord object. + * @param workerGrid The worker grid attributes, represented as a GridAttr + * object. The shape of the worker grid is expected to be in the format {y, x}. + * + * @throws std::runtime_error If the worker grid size does not match the compute + * grid size. + */ +void checkGrid(const ::tt::tt_metal::CoreCoord &computeGridSize, + const mlir::tt::GridAttr &workerGrid) { + // metal CoreCoord holds x,y + // GridAttr holds shape {y,x} + if ((static_cast(workerGrid.getShape()[1]) != computeGridSize.x) || + (static_cast(workerGrid.getShape()[0]) != computeGridSize.y)) { + throw std::runtime_error("Selected worker grid is different than available " + "grid size. 
Compute Grid Size: " + + computeGridSize.str() + ", Worker Grid Size: (x=" + + std::to_string(workerGrid.getShape()[1]) + ",y=" + + std::to_string(workerGrid.getShape()[0]) + ")"); } } +} // namespace detail +#endif // TTMLIR_ENABLE_OPMODEL -::tt::tt_metal::MemoryConfig -getMemoryConfig(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { - - auto tensorMemoryLayout = getTensorMemoryLayout(layout); - auto bufferType = getBufferType(layout.getMemref()); +//===----------------------------------------------------------------------===// +// Device +//===----------------------------------------------------------------------===// - auto shardSpec = layout_get_shard_spec(layout); - return ::tt::tt_metal::MemoryConfig(tensorMemoryLayout, bufferType, - shardSpec); +std::tuple> +Device::getDeviceConstraints(const mlir::tt::GridAttr &workerGrid) { +#ifdef TTMLIR_ENABLE_OPMODEL + try { + detail::checkGrid(SingletonDeviceContext::getInstance() + .getDevice() + ->compute_with_storage_grid_size(), + workerGrid); + } catch (const std::exception &e) { + return std::make_tuple(false, e.what()); + } +#endif + return std::make_tuple(true, std::nullopt); } -} // namespace detail -#endif // TTMLIR_ENABLE_OPMODEL - //===----------------------------------------------------------------------===// // ReluOp //===----------------------------------------------------------------------===// - -bool ReluOpInterface::isLegal( +std::tuple>, + std::optional> +ReluOpInterface::getOpConstraints( + const ::llvm::ArrayRef &inputShape, const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const ::llvm::ArrayRef &outputShape, const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { +#ifdef TTMLIR_ENABLE_OPMODEL + auto reluOpQuery = [](const ::llvm::ArrayRef &inputShape, + const ::mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const ::llvm::ArrayRef &outputShape, + const ::mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { + // open device device, will close it at the end of function + ::tt::tt_metal::v0::Device *device = + SingletonDeviceContext::getInstance().getDevice(); + + // prepare io specs + const ::ttnn::TensorSpec input_spec = + conversion::getTensorSpec(inputShape, inputLayout); + detail::checkGrid(device->compute_with_storage_grid_size(), + input_spec.memory_config()); + const ::ttnn::TensorSpec output_spec = + conversion::getTensorSpec(outputShape, outputLayout); + detail::checkGrid(device->compute_with_storage_grid_size(), + output_spec.memory_config()); + + // run op constraint query + return ::ttnn::graph::query_op_constraints( + ::ttnn::relu, device, input_spec, + output_spec.tensor_layout().get_memory_config()); + }; + + return operation::getOpConstraints("ReluOpInterface", reluOpQuery, inputShape, + inputLayout, outputShape, outputLayout); +#else + return std::make_tuple(true, std::make_tuple(0, 0, 0), std::nullopt); +#endif // TTMLIR_ENABLE_OPMODEL +} +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// +std::tuple>, + std::optional> +AddOpInterface::getOpConstraints( + const ::llvm::ArrayRef &inputShape_a, + const ::mlir::tt::ttnn::TTNNLayoutAttr &inputLayout_a, + const ::llvm::ArrayRef &inputShape_b, + const ::mlir::tt::ttnn::TTNNLayoutAttr &inputLayout_b, + const ::llvm::ArrayRef &outputShape, + const ::mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { #ifdef TTMLIR_ENABLE_OPMODEL - return true; // to wire into tt-metal with the next uplift + auto addOpQuery = [](const ::llvm::ArrayRef &inputShape_a, + 
const ::mlir::tt::ttnn::TTNNLayoutAttr &inputLayout_a, + const ::llvm::ArrayRef &inputShape_b, + const ::mlir::tt::ttnn::TTNNLayoutAttr &inputLayout_b, + const ::llvm::ArrayRef &outputShape, + const ::mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { + // open device device, will close it at the end of function + ::tt::tt_metal::v0::Device *device = + SingletonDeviceContext::getInstance().getDevice(); + + // prepare io specs + const ::ttnn::TensorSpec input_spec_a = + conversion::getTensorSpec(inputShape_a, inputLayout_a); + detail::checkGrid(device->compute_with_storage_grid_size(), + input_spec_a.memory_config()); + const ::ttnn::TensorSpec input_spec_b = + conversion::getTensorSpec(inputShape_b, inputLayout_b); + detail::checkGrid(device->compute_with_storage_grid_size(), + input_spec_b.memory_config()); + const ::ttnn::TensorSpec output_spec = + conversion::getTensorSpec(outputShape, outputLayout); + detail::checkGrid(device->compute_with_storage_grid_size(), + output_spec.memory_config()); + + return ::ttnn::graph::query_op_constraints( + ::ttnn::add, device, input_spec_a, input_spec_b, + output_spec.data_type(), + output_spec.tensor_layout().get_memory_config()); + }; + + return operation::getOpConstraints("AddOpInterface", addOpQuery, inputShape_a, + inputLayout_a, inputShape_b, inputLayout_b, + outputShape, outputLayout); #else - return true; + return std::make_tuple(true, std::make_tuple(0, 0, 0), std::nullopt); #endif // TTMLIR_ENABLE_OPMODEL } -std::tuple ReluOpInterface::getOpL1Usage( - const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, +//===----------------------------------------------------------------------===// +// SoftmaxOp +//===----------------------------------------------------------------------===// +std::tuple>, + std::optional> +SoftmaxOpInterface::getOpConstraints( + const llvm::ArrayRef &inputShape, + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, const int dim_arg, + const llvm::ArrayRef &outputShape, const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { #ifdef TTMLIR_ENABLE_OPMODEL - return std::make_tuple(0, 0, 0); // to wire into tt-metal with the next uplift + auto softmaxOpQuery = [](const llvm::ArrayRef &inputShape, + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const int dim_arg, + const llvm::ArrayRef &outputShape, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { + // open device device, will close it at the end of function + ::tt::tt_metal::v0::Device *device = + SingletonDeviceContext::getInstance().getDevice(); + + // prepare io specs + const ::ttnn::TensorSpec input_spec = + conversion::getTensorSpec(inputShape, inputLayout); + detail::checkGrid(device->compute_with_storage_grid_size(), + input_spec.memory_config()); + const ::ttnn::TensorSpec output_spec = + conversion::getTensorSpec(outputShape, outputLayout); + detail::checkGrid(device->compute_with_storage_grid_size(), + output_spec.memory_config()); + + // run op constraint query + return ::ttnn::graph::query_op_constraints( + ::ttnn::softmax, device, input_spec, dim_arg, + output_spec.tensor_layout().get_memory_config()); + }; + + return operation::getOpConstraints("SoftmaxOpInterface", softmaxOpQuery, + inputShape, inputLayout, dim_arg, + outputShape, outputLayout); +#else + return std::make_tuple(true, std::make_tuple(0, 0, 0), std::nullopt); +#endif // TTMLIR_ENABLE_OPMODEL +} + +//===----------------------------------------------------------------------===// +// MatmulOp +//===----------------------------------------------------------------------===// +std::tuple>, + 
std::optional> +MatmulOpInterface::getOpConstraints( + const llvm::ArrayRef &inputShape_a, + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout_a, + const llvm::ArrayRef &inputShape_b, + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout_b, + const llvm::ArrayRef &outputShape, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout, bool transpose_a, + bool transpose_b) { +#ifdef TTMLIR_ENABLE_OPMODEL + auto matmulOpQuery = [](const llvm::ArrayRef &inputShape_a, + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout_a, + const llvm::ArrayRef &inputShape_b, + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout_b, + const llvm::ArrayRef &outputShape, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout, + bool transpose_a, bool transpose_b) { + // open device device, will close it at the end of function + ::tt::tt_metal::v0::Device *device = + SingletonDeviceContext::getInstance().getDevice(); + + // prepare io specs + const ::ttnn::TensorSpec input_spec_a = + conversion::getTensorSpec(inputShape_a, inputLayout_a); + detail::checkGrid(device->compute_with_storage_grid_size(), + input_spec_a.memory_config()); + const ::ttnn::TensorSpec input_spec_b = + conversion::getTensorSpec(inputShape_b, inputLayout_b); + detail::checkGrid(device->compute_with_storage_grid_size(), + input_spec_b.memory_config()); + const ::ttnn::TensorSpec output_spec = + conversion::getTensorSpec(outputShape, outputLayout); + detail::checkGrid(device->compute_with_storage_grid_size(), + output_spec.memory_config()); + + // run op constraint query + return ::ttnn::graph::query_op_constraints( + ::ttnn::matmul, device, input_spec_a, input_spec_b, transpose_a, + transpose_b, output_spec.tensor_layout().get_memory_config(), + output_spec.data_type()); + }; + + return operation::getOpConstraints("MatmulOpInterface", matmulOpQuery, + inputShape_a, inputLayout_a, inputShape_b, + inputLayout_b, outputShape, outputLayout, + transpose_a, transpose_b); #else - return std::make_tuple(0, 0, 0); + return std::make_tuple(true, std::make_tuple(0, 0, 0), std::nullopt); #endif // TTMLIR_ENABLE_OPMODEL } diff --git a/test/lit.cfg.py b/test/lit.cfg.py index d65acc7b2..74204a8f3 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -98,3 +98,21 @@ def set_system_desc_features(system_desc): ], append_path=True, ) + +if "TT_MLIR_HOME" in os.environ: + print(f"{os.environ['TT_MLIR_HOME']}") + llvm_config.with_environment("TT_MLIR_HOME", os.environ["TT_MLIR_HOME"]) +else: + raise OSError("Error: TT_MLIR_HOME not set") + +if "TT_METAL_HOME" in os.environ: + print(f"{os.environ['TT_METAL_HOME']}") + llvm_config.with_environment("TT_METAL_HOME", os.environ["TT_METAL_HOME"]) +else: + raise OSError("Error: TT_METAL_HOME not set") + +if "ARCH_NAME" in os.environ: + print(f"ARCH_NAME={os.environ['ARCH_NAME']}") + llvm_config.with_environment("ARCH_NAME", os.environ["ARCH_NAME"]) +else: + raise OSError("Error: ARCH_NAME not set.") diff --git a/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir index 6989e765f..eb9cc51f8 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir @@ -3,7 +3,7 @@ module attributes {} { func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> tensor<1x32x32xf32> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK-DAG: 
#[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1, (d0, d1) -> (0, d1 floordiv 8, d1 mod 8)>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #dram>, > %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/test_override_reshard_edges.mlir b/test/ttmlir/Dialect/TTNN/optimizer/test_override_reshard_edges.mlir index 08e6da116..4a375fa6e 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/test_override_reshard_edges.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/test_override_reshard_edges.mlir @@ -7,7 +7,7 @@ module attributes {tt.device = #device} { func.func @main(%arg0: tensor<1x32x32xf32, #ttnn_layout>, %arg1: tensor<1x32x32xf32, #ttnn_layout>, %arg2: tensor<1x32x32xf32, #ttnn_layout>) -> tensor<1x32x32xf32, #ttnn_layout> { // CHECK: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #dram>, > - // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > + // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1, (d0, d1) -> (0, d1 floordiv 8, d1 mod 8)>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, f32>, #dram>, > %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<32x32>>, >}> : (tensor<1x32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1> diff --git a/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir b/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir index 96798905c..bbe769b7a 100644 --- a/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir +++ b/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir @@ -1,11 +1,11 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true" %s > %t.mlir -// RUN: FileCheck %s --input-file=%t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true" -o output_file.mlir %s +// RUN: FileCheck %s --input-file=output_file.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer output_file.mlir > %t.ttnn #loc = loc("MNISTLinear":4294967295:0) module @"tt-forge-graph" attributes {} { func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> { - // CHECK-DAG: #[[LAYOUT_10:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x8>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > - // CHECK-DAG: #[[LAYOUT_11:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > + 
// CHECK-DAG: #[[LAYOUT_10:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x8, (d0, d1) -> (0, d1 floordiv 8, d1 mod 8)>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_11:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1, (d0, d1) -> (0, d1 floordiv 8, d1 mod 8)>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8) // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_10]]> %1 = "ttir.matmul"(%arg0, %arg4, %0) : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) diff --git a/test/unittests/CMakeLists.txt b/test/unittests/CMakeLists.txt index a66e00a43..577f1969b 100644 --- a/test/unittests/CMakeLists.txt +++ b/test/unittests/CMakeLists.txt @@ -7,3 +7,4 @@ endfunction() add_subdirectory(TestScheduler) add_subdirectory(Optimizer) +add_subdirectory(OpModel) diff --git a/test/unittests/OpModel/CMakeLists.txt b/test/unittests/OpModel/CMakeLists.txt new file mode 100644 index 000000000..9c34667d0 --- /dev/null +++ b/test/unittests/OpModel/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TTNN) diff --git a/test/unittests/OpModel/TTNN/CMakeLists.txt b/test/unittests/OpModel/TTNN/CMakeLists.txt new file mode 100644 index 000000000..8a0ff2b71 --- /dev/null +++ b/test/unittests/OpModel/TTNN/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Conversion) +add_subdirectory(Lib) +add_subdirectory(Op) diff --git a/test/unittests/OpModel/TTNN/Conversion/CMakeLists.txt b/test/unittests/OpModel/TTNN/Conversion/CMakeLists.txt new file mode 100644 index 000000000..35c9e3a88 --- /dev/null +++ b/test/unittests/OpModel/TTNN/Conversion/CMakeLists.txt @@ -0,0 +1,25 @@ +if (TTMLIR_ENABLE_OPMODEL) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# TestConversion is used to test MLIR to TTNN types conversion +add_executable(TestConversion + TestConversion.cpp +) + +target_include_directories(TestConversion + PUBLIC + ${PROJECT_SOURCE_DIR}/lib/OpModel/TTNN/ + ${PROJECT_SOURCE_DIR}/test/unittests/OpModel/TTNN/ +) + +target_link_libraries(TestConversion + PRIVATE + gtest + gtest_main + TTNNOpModelLib + MLIRTTDialect + MLIRTTIRDialect + MLIRTTNNDialect +) +endif() diff --git a/test/unittests/OpModel/TTNN/Conversion/TestConversion.cpp b/test/unittests/OpModel/TTNN/Conversion/TestConversion.cpp new file mode 100644 index 000000000..fd5dd07a0 --- /dev/null +++ b/test/unittests/OpModel/TTNN/Conversion/TestConversion.cpp @@ -0,0 +1,530 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "Conversion.hpp" +#include "OpModelFixture.h" + +#include "llvm/ADT/SmallVector.h" +#include "gtest/gtest.h" + +class MlirToTtnnConversion : public OpModelFixture {}; + +//================================================================================ +// getDataType +//================================================================================ +class MlirToTtnnConversionSimpleShape + : public MlirToTtnnConversion, + public testing::WithParamInterface> {}; + +TEST_P(MlirToTtnnConversionSimpleShape, SimpleShape) { + const auto &tensorShape = GetParam(); + const auto &shape = + mlir::tt::op_model::ttnn::conversion::getSimpleShape(tensorShape); + + EXPECT_EQ(shape.size(), tensorShape.size()); + for (size_t i = 0; i < shape.size(); ++i) { + EXPECT_EQ(shape[i], tensorShape[i]); + } +} + +INSTANTIATE_TEST_SUITE_P( + ToSimpleShape, MlirToTtnnConversionSimpleShape, + ::testing::Values(mlir::SmallVector{64, 32}, + mlir::SmallVector{64, 32, 128}, + 
mlir::SmallVector{64, 32, 128, 256})); + +//================================================================================ +// getSimpleShape +//================================================================================ +class MlirToTtnnConversionDataType + : public MlirToTtnnConversion, + public testing::WithParamInterface< + std::tuple> {}; + +TEST_P(MlirToTtnnConversionDataType, DataType) { + const auto &dataType = std::get<0>(GetParam()); + const auto &expectedDataType = std::get<1>(GetParam()); + + llvm::SmallVector tensorShape = {32, 32}; + auto layout = mlir::tt::ttnn::TTNNLayoutAttr::get( + &context, tensorShape, + mlir::tt::TileType::get(&context, {32, 32}, dataType), + mlir::tt::ttnn::BufferType::L1, mlir::tt::GridAttr::get(&context, {8, 8}), + mlir::tt::ttnn::TensorMemoryLayoutAttr::get( + &context, mlir::tt::ttnn::TensorMemoryLayout::Interleaved)); + + auto convertedDataType = + mlir::tt::op_model::ttnn::conversion::getDataType(layout); + EXPECT_EQ(convertedDataType, expectedDataType); +} + +INSTANTIATE_TEST_SUITE_P( + ToDataType, MlirToTtnnConversionDataType, + ::testing::Values(std::make_tuple(mlir::tt::DataType::Float32, + ::tt::tt_metal::DataType::FLOAT32), + std::make_tuple(mlir::tt::DataType::BFloat16, + ::tt::tt_metal::DataType::BFLOAT16), + std::make_tuple(mlir::tt::DataType::BFP_BFloat8, + ::tt::tt_metal::DataType::BFLOAT8_B), + std::make_tuple(mlir::tt::DataType::BFP_BFloat4, + ::tt::tt_metal::DataType::BFLOAT4_B), + std::make_tuple(mlir::tt::DataType::UInt32, + ::tt::tt_metal::DataType::UINT32), + std::make_tuple(mlir::tt::DataType::UInt16, + ::tt::tt_metal::DataType::UINT16), + std::make_tuple(mlir::tt::DataType::UInt8, + ::tt::tt_metal::DataType::UINT8))); + +//================================================================================ +// getShardShape +//================================================================================ +class MlirToTtnnConversionShardShape + : public MlirToTtnnConversion, + public testing::WithParamInterface, llvm::SmallVector, + mlir::tt::ttnn::BufferType, mlir::tt::ttnn::TensorMemoryLayout>> {}; + +TEST_P(MlirToTtnnConversionShardShape, ShardShape) { + const auto &virtualGrid = std::get<0>(GetParam()); + const auto &tensorShape = std::get<1>(GetParam()); + const auto &bufferType = std::get<2>(GetParam()); + const auto &tensorMemoryLayout = std::get<3>(GetParam()); + + if (tensorMemoryLayout == mlir::tt::ttnn::TensorMemoryLayout::WidthSharded) { + EXPECT_EQ(virtualGrid[0], 1); + } else if (tensorMemoryLayout == + mlir::tt::ttnn::TensorMemoryLayout::HeightSharded) { + EXPECT_EQ(virtualGrid[1], 1); + } + + const auto layout = CreateTiledLayout(tensorShape, bufferType, + tensorMemoryLayout, virtualGrid); + const auto shardShape = + mlir::tt::op_model::ttnn::conversion::getShardShape(layout); + + EXPECT_EQ(shardShape[0], + ttmlir::utils::alignUp(tensorShape[0] / virtualGrid[0], 32L)); + EXPECT_EQ(shardShape[1], + ttmlir::utils::alignUp(tensorShape[1] / virtualGrid[1], 32L)); +} + +INSTANTIATE_TEST_SUITE_P( + ToShardShape, MlirToTtnnConversionShardShape, + ::testing::Values( + std::make_tuple(llvm::SmallVector{8, 8}, + llvm::SmallVector{4096, 2048}, + mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded), + std::make_tuple(llvm::SmallVector{6, 6}, + llvm::SmallVector{4096, 2048}, + mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded), + std::make_tuple(llvm::SmallVector{64, 1}, + llvm::SmallVector{4096, 2048}, + mlir::tt::ttnn::BufferType::L1, + 
mlir::tt::ttnn::TensorMemoryLayout::HeightSharded), + std::make_tuple(llvm::SmallVector{36, 1}, + llvm::SmallVector{4096, 2048}, + mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::HeightSharded), + std::make_tuple(llvm::SmallVector{1, 64}, + llvm::SmallVector{4096, 2048}, + mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::WidthSharded), + std::make_tuple(llvm::SmallVector{1, 36}, + llvm::SmallVector{4096, 2048}, + mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::WidthSharded))); + +//================================================================================ +// getPageLayout +//================================================================================ +TEST_F(MlirToTtnnConversion, PageLayout) { + llvm::SmallVector tensorShape = {16 * 64 * 32, 32}; + + mlir::tt::ttnn::TTNNLayoutAttr tiledLayout = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::DRAM, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); + + EXPECT_EQ(mlir::tt::op_model::ttnn::conversion::getPageLayout(tiledLayout), + ::tt::tt_metal::Layout::TILE); + + mlir::tt::ttnn::TTNNLayoutAttr rowLayout = + CreateRowMajorLayout(tensorShape, mlir::tt::ttnn::BufferType::DRAM, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); + EXPECT_EQ(mlir::tt::op_model::ttnn::conversion::getPageLayout(rowLayout), + ::tt::tt_metal::Layout::ROW_MAJOR); +} + +//================================================================================ +// getCoreRangeSet +//================================================================================ + +class ShardedCoreRangeSet + : public MlirToTtnnConversion, + public testing::WithParamInterface< + std::tuple, // shard shape + llvm::SmallVector, // virtual grid shape + ::tt::tt_metal::CoreRangeSet>> // expected core range set +{}; + +TEST_P(ShardedCoreRangeSet, ShardedCoreRangeSet) { + const auto &tensorMemoryLayout = std::get<0>(GetParam()); + const auto &tensorShape = std::get<1>(GetParam()); + const auto &grid = std::get<2>(GetParam()); + const auto &expectedCoreRangeSet = std::get<3>(GetParam()); + + const auto layout = CreateTiledLayout( + tensorShape, mlir::tt::ttnn::BufferType::L1, tensorMemoryLayout, grid); + + const auto coreRangeSet = + mlir::tt::op_model::ttnn::conversion::getCoreRangeSet(layout); + + EXPECT_EQ(coreRangeSet.size(), expectedCoreRangeSet.size()); + for (const auto &[v, r] : + llvm::zip(coreRangeSet.ranges(), expectedCoreRangeSet.ranges())) { + EXPECT_EQ(v.start_coord, r.start_coord); + EXPECT_EQ(v.end_coord, r.end_coord); + } +} + +INSTANTIATE_TEST_SUITE_P( + ToCoreRangeSet, ShardedCoreRangeSet, + ::testing::Values( + std::make_tuple(mlir::tt::ttnn::TensorMemoryLayout::WidthSharded, + llvm::SmallVector{32, 56 * 32}, + llvm::SmallVector{1, 56}, + CoreRangeSet{ + CoreRange(CoreCoord(0, 0), CoreCoord(7, 6))}), + std::make_tuple(mlir::tt::ttnn::TensorMemoryLayout::WidthSharded, + llvm::SmallVector{32, 13 * 32}, + llvm::SmallVector{1, 13}, + CoreRangeSet{std::set{ + CoreRange(CoreCoord(0, 0), CoreCoord(7, 0)), + CoreRange(CoreCoord(0, 1), CoreCoord(4, 1))}}), + std::make_tuple(mlir::tt::ttnn::TensorMemoryLayout::HeightSharded, + llvm::SmallVector{56 * 32, 32}, + llvm::SmallVector{56, 1}, + CoreRangeSet{ + CoreRange(CoreCoord(0, 0), CoreCoord(7, 6))}), + std::make_tuple(mlir::tt::ttnn::TensorMemoryLayout::HeightSharded, + llvm::SmallVector{13 * 32, 32}, + llvm::SmallVector{13, 1}, + CoreRangeSet{std::set{ + CoreRange(CoreCoord(0, 0), CoreCoord(7, 0)), + CoreRange(CoreCoord(0, 1), CoreCoord(4, 1))}}), + 
std::make_tuple(mlir::tt::ttnn::TensorMemoryLayout::BlockSharded, + llvm::SmallVector{7 * 32, 8 * 32}, + llvm::SmallVector{7, 8}, + CoreRangeSet{ + CoreRange(CoreCoord(0, 0), CoreCoord(7, 6))}), + std::make_tuple(mlir::tt::ttnn::TensorMemoryLayout::BlockSharded, + llvm::SmallVector{4 * 11 * 32, 8 * 13 * 32}, + llvm::SmallVector{4, 8}, + CoreRangeSet{ + CoreRange(CoreCoord(0, 0), CoreCoord(7, 3))}))); + +//================================================================================ +// getShardSpec +//================================================================================ + +TEST_F(MlirToTtnnConversion, ShardWithInterleaved) { + const llvm::SmallVector tensorShape = {56 * 32, 56 * 32}; + + // dram interleaved + { + const auto layout = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::DRAM, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); + const auto shardSpec = + mlir::tt::op_model::ttnn::conversion::getShardSpec(layout); + EXPECT_EQ(shardSpec.has_value(), false); + } + + // l1 interleaved + { + const auto layout = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); + const auto shardSpec = + mlir::tt::op_model::ttnn::conversion::getShardSpec(layout); + EXPECT_EQ(shardSpec.has_value(), false); + } +} + +class ShardSpecFixture + : public MlirToTtnnConversion, + public testing::WithParamInterface< + std::tuple, // tensor shape + llvm::SmallVector, // phy grid shape + llvm::SmallVector>> // expected shard shape +{}; + +TEST_P(ShardSpecFixture, ShardSpec) { + const auto bufferType = std::get<0>(GetParam()); + const auto tensorMemoryLayout = std::get<1>(GetParam()); + const auto tensorShape = std::get<2>(GetParam()); + const auto phyGridShape = std::get<3>(GetParam()); + const auto expected_shard_shape = std::get<4>(GetParam()); + + auto virtualGrid = + GetVirtualGridShape(tensorShape, tensorMemoryLayout, phyGridShape); + + const auto layout = CreateTiledLayout( + tensorShape, bufferType, tensorMemoryLayout, virtualGrid, phyGridShape); + + const auto shardSpec = + mlir::tt::op_model::ttnn::conversion::getShardSpec(layout); + + EXPECT_EQ(shardSpec.has_value(), true); + + EXPECT_EQ(shardSpec->shape[0], expected_shard_shape[0]); + EXPECT_EQ(shardSpec->shape[1], expected_shard_shape[1]); + EXPECT_EQ(shardSpec->grid.size(), 1); + + EXPECT_EQ(shardSpec->grid.ranges()[0].start_coord, CoreCoord(0, 0)); + EXPECT_EQ(shardSpec->grid.ranges()[0].end_coord, + CoreCoord(phyGridShape[1] - 1, phyGridShape[0] - 1)); + // These fields are not utilized on the compiler + // side, we are setting them to default values. + // Purpose of testing them is to update the test + // if the compiler side changes. 
+ EXPECT_EQ(shardSpec->orientation, ShardOrientation::ROW_MAJOR); + EXPECT_EQ(shardSpec->halo, false); + EXPECT_EQ(shardSpec->mode, ShardMode::PHYSICAL); + EXPECT_EQ(shardSpec->physical_shard_shape.has_value(), false); +} + +INSTANTIATE_TEST_SUITE_P( + ToShardSpec, ShardSpecFixture, + ::testing::Values( + std::make_tuple(mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded, + llvm::SmallVector{4096, 2048}, + llvm::SmallVector{8, 8}, + llvm::SmallVector{512, 256}), + std::make_tuple(mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded, + llvm::SmallVector{7 * 512, 8 * 256}, + llvm::SmallVector{7, 8}, + llvm::SmallVector{512, 256}), + std::make_tuple(mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded, + llvm::SmallVector{4096, 2048}, + llvm::SmallVector{4, 4}, + llvm::SmallVector{1024, 512}), + std::make_tuple(mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::HeightSharded, + llvm::SmallVector{4096, 2048}, + llvm::SmallVector{8, 8}, + llvm::SmallVector{64, 2048}), + std::make_tuple(mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::HeightSharded, + llvm::SmallVector{56 * 64, 2048}, + llvm::SmallVector{7, 8}, + llvm::SmallVector{64, 2048}), + std::make_tuple(mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::HeightSharded, + llvm::SmallVector{4096, 2048}, + llvm::SmallVector{4, 4}, + llvm::SmallVector{256, 2048}), + std::make_tuple(mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::WidthSharded, + llvm::SmallVector{4096, 2048}, + llvm::SmallVector{8, 8}, + llvm::SmallVector{4096, 32}), + std::make_tuple(mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::WidthSharded, + llvm::SmallVector{4096, 56 * 32}, + llvm::SmallVector{7, 8}, + llvm::SmallVector{4096, 32}), + std::make_tuple(mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::WidthSharded, + llvm::SmallVector{4096, 2048}, + llvm::SmallVector{4, 4}, + llvm::SmallVector{4096, 128}))); + +//================================================================================ +// getBufferType +//================================================================================ +class MlirToTtnnConversionBufferType + : public MlirToTtnnConversion, + public testing::WithParamInterface< + std::tuple> {}; + +TEST_P(MlirToTtnnConversionBufferType, BufferType) { + const auto &mlirBufferType = std::get<0>(GetParam()); + const auto &expectedBufferType = std::get<1>(GetParam()); + + auto layout = + CreateTiledLayout({32, 32}, mlirBufferType, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); + + const auto bufferType = + mlir::tt::op_model::ttnn::conversion::getBufferType(layout); + EXPECT_EQ(bufferType, expectedBufferType); +} + +INSTANTIATE_TEST_SUITE_P( + ToBufferType, MlirToTtnnConversionBufferType, + ::testing::Values(std::make_tuple(mlir::tt::ttnn::BufferType::L1, + tt::tt_metal::BufferType::L1), + std::make_tuple(mlir::tt::ttnn::BufferType::DRAM, + tt::tt_metal::BufferType::DRAM), + std::make_tuple(mlir::tt::ttnn::BufferType::SystemMemory, + tt::tt_metal::BufferType::SYSTEM_MEMORY), + std::make_tuple(mlir::tt::ttnn::BufferType::L1Small, + tt::tt_metal::BufferType::L1_SMALL), + std::make_tuple(mlir::tt::ttnn::BufferType::Trace, + tt::tt_metal::BufferType::TRACE))); + +//================================================================================ +// getTensorMemoryLayout 
+//================================================================================ +class MlirToTnnConversionTensorMemoryLayout + : public MlirToTtnnConversion, + public testing::WithParamInterface< + std::tuple> {}; + +TEST_P(MlirToTnnConversionTensorMemoryLayout, MemoryConfig) { + const auto &mlirTensorMemoryLayout = + std::get(GetParam()); + const auto &expectedTensorMemoryLayout = + std::get(GetParam()); + + const llvm::SmallVector tensorShape = {56 * 32, 56 * 32}; + + auto layout = CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1, + mlirTensorMemoryLayout); + + const auto tensorMemoryLayout = + mlir::tt::op_model::ttnn::conversion::getTensorMemoryLayout(layout); + EXPECT_EQ(tensorMemoryLayout, expectedTensorMemoryLayout); +} + +INSTANTIATE_TEST_SUITE_P( + ToTensorMemoryLayout, MlirToTnnConversionTensorMemoryLayout, + ::testing::Values( + std::make_tuple(mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + tt::tt_metal::TensorMemoryLayout::INTERLEAVED), + std::make_tuple(mlir::tt::ttnn::TensorMemoryLayout::SingleBank, + tt::tt_metal::TensorMemoryLayout::SINGLE_BANK), + std::make_tuple(mlir::tt::ttnn::TensorMemoryLayout::HeightSharded, + tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED), + std::make_tuple(mlir::tt::ttnn::TensorMemoryLayout::WidthSharded, + tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED), + std::make_tuple(mlir::tt::ttnn::TensorMemoryLayout::BlockSharded, + tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED))); + +//================================================================================ +// getMemoryConfig +//================================================================================ +class MlirToTtnnConversionMemoryConfig + : public MlirToTtnnConversion, + public testing::WithParamInterface> {}; + +TEST_P(MlirToTtnnConversionMemoryConfig, MemoryConfig) { + const auto &mlirBufferType = std::get<0>(GetParam()); + const auto &mlirTensorMemoryLayout = std::get<1>(GetParam()); + const llvm::SmallVector tensorShape = {4096, 2048}; + + auto layout = + CreateTiledLayout(tensorShape, mlirBufferType, mlirTensorMemoryLayout); + + const auto memoryConfig = + mlir::tt::op_model::ttnn::conversion::getMemoryConfig(layout); + + EXPECT_EQ(memoryConfig.is_l1(), + mlirBufferType == mlir::tt::ttnn::BufferType::L1); + EXPECT_EQ(memoryConfig.is_dram(), + mlirBufferType == mlir::tt::ttnn::BufferType::DRAM); + EXPECT_EQ(memoryConfig.is_sharded(), + mlirTensorMemoryLayout != + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); +} + +INSTANTIATE_TEST_SUITE_P( + ToMemoryConfig, MlirToTtnnConversionMemoryConfig, + ::testing::Combine( + ::testing::Values(mlir::tt::ttnn::BufferType::DRAM, + mlir::tt::ttnn::BufferType::L1), + ::testing::Values(mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::TensorMemoryLayout::HeightSharded, + mlir::tt::ttnn::TensorMemoryLayout::WidthSharded, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded))); + +//================================================================================ +// getTensorLayout +//================================================================================ +TEST_F(MlirToTtnnConversion, TensorLayout) { + const llvm::SmallVector tensorShape = {56 * 32, 56 * 32}; + // test tilized layout + { + const auto layout = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded); + + const auto tensorLayout = + mlir::tt::op_model::ttnn::conversion::getTensorLayout(layout); + + EXPECT_EQ(tensorLayout.get_data_type(), tt::tt_metal::DataType::BFLOAT16); + 
EXPECT_EQ(tensorLayout.get_layout(), tt::tt_metal::Layout::TILE); + EXPECT_EQ(tensorLayout.get_memory_config().is_sharded(), true); + } + // test row-major layout + { + const auto layout = + CreateRowMajorLayout(tensorShape, mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded); + + const auto tensorLayout = + mlir::tt::op_model::ttnn::conversion::getTensorLayout(layout); + + EXPECT_EQ(tensorLayout.get_data_type(), tt::tt_metal::DataType::BFLOAT16); + EXPECT_EQ(tensorLayout.get_layout(), tt::tt_metal::Layout::ROW_MAJOR); + EXPECT_EQ(tensorLayout.get_memory_config().is_sharded(), true); + } +} + +//================================================================================ +// getTensorSpec +//================================================================================ +TEST_F(MlirToTtnnConversion, TensorSpec) { + const llvm::SmallVector tensorShape = {56 * 32, 56 * 32}; + // test tilized layout + { + const auto layout = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded); + const auto ttnnSimpleShape = + mlir::tt::op_model::ttnn::conversion::getSimpleShape(tensorShape); + const auto ttnnLayout = + mlir::tt::op_model::ttnn::conversion::getTensorLayout(layout); + const auto tensorSpec = mlir::tt::op_model::ttnn::conversion::getTensorSpec( + tensorShape, layout); + EXPECT_EQ(tensorSpec.logical_shape().volume(), ttnnSimpleShape.volume()); + EXPECT_EQ(tensorSpec.page_config().get_layout(), + tt::tt_metal::Layout::TILE); + } + // test row-major layout + { + const auto layout = + CreateRowMajorLayout(tensorShape, mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded); + const auto ttnnSimpleShape = + mlir::tt::op_model::ttnn::conversion::getSimpleShape(tensorShape); + const auto ttnnLayout = + mlir::tt::op_model::ttnn::conversion::getTensorLayout(layout); + const auto tensorSpec = mlir::tt::op_model::ttnn::conversion::getTensorSpec( + tensorShape, layout); + EXPECT_EQ(tensorSpec.logical_shape().volume(), ttnnSimpleShape.volume()); + EXPECT_EQ(tensorSpec.page_config().get_layout(), + tt::tt_metal::Layout::ROW_MAJOR); + } +} diff --git a/test/unittests/OpModel/TTNN/Lib/CMakeLists.txt b/test/unittests/OpModel/TTNN/Lib/CMakeLists.txt new file mode 100644 index 000000000..f32094b6f --- /dev/null +++ b/test/unittests/OpModel/TTNN/Lib/CMakeLists.txt @@ -0,0 +1,24 @@ +if (TTMLIR_ENABLE_OPMODEL) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +add_executable(TestOpModelLib +TestOpModelLib.cpp +) + +target_include_directories(TestOpModelLib + PUBLIC + ${PROJECT_SOURCE_DIR}/lib/OpModel/TTNN/ + ${PROJECT_SOURCE_DIR}/test/unittests/OpModel/TTNN/ +) + +target_link_libraries(TestOpModelLib + PRIVATE + gtest + gtest_main + TTNNOpModelLib + MLIRTTDialect + MLIRTTIRDialect + MLIRTTNNDialect +) +endif() diff --git a/test/unittests/OpModel/TTNN/Lib/TestOpModelLib.cpp b/test/unittests/OpModel/TTNN/Lib/TestOpModelLib.cpp new file mode 100644 index 000000000..27e25edf1 --- /dev/null +++ b/test/unittests/OpModel/TTNN/Lib/TestOpModelLib.cpp @@ -0,0 +1,665 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "OpModelFixture.h" + +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/OpModel/TTNN/TTNNOpModel.h" + +#include "llvm/ADT/SmallVector.h" +#include "gtest/gtest.h" + +namespace mlir::tt::op_model::ttnn { + +class OpModelTest : public OpModelFixture {}; + +TEST_F(OpModelTest, ReluInterleaved) { + 
const llvm::SmallVector tensorShape = {workerCoresN300, 1024}; + const auto workerGrid = CreateWorkerGrid(gridShapeHwN300); + const mlir::tt::ttnn::TTNNLayoutAttr inputLayout_dram = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::DRAM, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); + const mlir::tt::ttnn::TTNNLayoutAttr inputLayout_l1 = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); + + bool legal = false; + std::optional> l1Usage = std::nullopt; + std::optional errorMsg = ""; + size_t cb_size = 0; + size_t peak_size = 0; + size_t output_size = 0; + + std::tie(legal, errorMsg) = Device::getDeviceConstraints(workerGrid); + EXPECT_TRUE(legal); + + std::tie(legal, l1Usage, errorMsg) = ReluOpInterface::getOpConstraints( + tensorShape, inputLayout_dram, tensorShape, inputLayout_dram); + EXPECT_TRUE(legal); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 8192); + EXPECT_EQ(output_size, 0); + EXPECT_EQ(peak_size, 0); + + std::tie(legal, l1Usage, errorMsg) = ReluOpInterface::getOpConstraints( + tensorShape, inputLayout_dram, tensorShape, inputLayout_l1); + EXPECT_TRUE(legal); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 8192); + EXPECT_EQ(output_size, 4096); + EXPECT_EQ(peak_size, 4096); + + std::tie(legal, l1Usage, errorMsg) = ReluOpInterface::getOpConstraints( + tensorShape, inputLayout_l1, tensorShape, inputLayout_dram); + EXPECT_TRUE(legal); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 8192); + EXPECT_EQ(output_size, 0); + EXPECT_EQ(peak_size, 0); + + std::tie(legal, l1Usage, errorMsg) = ReluOpInterface::getOpConstraints( + tensorShape, inputLayout_l1, tensorShape, inputLayout_l1); + EXPECT_TRUE(legal); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 8192); + EXPECT_EQ(output_size, 4096); + EXPECT_EQ(peak_size, 4096); +} + +TEST_F(OpModelTest, ReluSharded) { + const llvm::SmallVector tensorShape = {14 * workerCoresN300 * 32, + 32}; + const auto workerGrid = CreateWorkerGrid(gridShapeHwN300); + const mlir::tt::ttnn::TTNNLayoutAttr inputLayout_l1_hs = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::HeightSharded); + const mlir::tt::ttnn::TTNNLayoutAttr inputLayout_l1_i = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); + + bool legal = false; + std::optional> l1Usage = std::nullopt; + std::optional errorMsg = ""; + size_t cb_size = 0; + size_t peak_size = 0; + size_t output_size = 0; + + std::tie(legal, errorMsg) = Device::getDeviceConstraints(workerGrid); + EXPECT_TRUE(legal); + + std::tie(legal, l1Usage, errorMsg) = ReluOpInterface::getOpConstraints( + tensorShape, inputLayout_l1_hs, tensorShape, inputLayout_l1_hs); + EXPECT_TRUE(legal); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 0); + EXPECT_EQ(output_size, tensorShape[0] * tensorShape[1] * 2 / workerCoresN300); + EXPECT_EQ(peak_size, tensorShape[0] * tensorShape[1] * 2 / 
workerCoresN300); + + legal = std::get<0>(ReluOpInterface::getOpConstraints( + tensorShape, inputLayout_l1_hs, tensorShape, inputLayout_l1_i)); + // Unary operation requires Input and Output memory layout to match. + EXPECT_EQ(legal, false); + legal = std::get<0>(ReluOpInterface::getOpConstraints( + tensorShape, inputLayout_l1_i, tensorShape, inputLayout_l1_hs)); + // Unary operation requires Input and Output memory layout to match. + EXPECT_EQ(legal, false); +} + +TEST_F(OpModelTest, SoftmaxInterleaved) { + const llvm::SmallVector tensorShape = {workerCoresN300, 1024}; + const auto workerGrid = CreateWorkerGrid(gridShapeHwN300); + const mlir::tt::ttnn::TTNNLayoutAttr inputLayout_dram = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::DRAM, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); + const mlir::tt::ttnn::TTNNLayoutAttr inputLayout_l1 = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); + + bool legal = false; + std::optional> l1Usage = std::nullopt; + std::optional errorMsg = ""; + size_t cb_size = 0; + size_t peak_size = 0; + size_t output_size = 0; + + std::tie(legal, errorMsg) = Device::getDeviceConstraints(workerGrid); + EXPECT_TRUE(legal); + + std::tie(legal, l1Usage, errorMsg) = SoftmaxOpInterface::getOpConstraints( + tensorShape, inputLayout_dram, -1, tensorShape, inputLayout_dram); + EXPECT_EQ(legal, true); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 137216); + EXPECT_EQ(output_size, 0); + EXPECT_EQ(peak_size, 0); + + std::tie(legal, l1Usage, errorMsg) = SoftmaxOpInterface::getOpConstraints( + tensorShape, inputLayout_dram, -1, tensorShape, inputLayout_l1); + EXPECT_EQ(legal, true); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 137216); + EXPECT_EQ(output_size, 4096); + EXPECT_EQ(peak_size, 4096); + + std::tie(legal, l1Usage, errorMsg) = SoftmaxOpInterface::getOpConstraints( + tensorShape, inputLayout_l1, -1, tensorShape, inputLayout_dram); + EXPECT_EQ(legal, true); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 137216); + EXPECT_EQ(output_size, 0); + EXPECT_EQ(peak_size, 0); + + std::tie(legal, l1Usage, errorMsg) = SoftmaxOpInterface::getOpConstraints( + tensorShape, inputLayout_l1, -1, tensorShape, inputLayout_l1); + EXPECT_EQ(legal, true); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 137216); + EXPECT_EQ(output_size, 4096); + EXPECT_EQ(peak_size, 4096); + + std::tie(legal, l1Usage, errorMsg) = SoftmaxOpInterface::getOpConstraints( + tensorShape, inputLayout_dram, -1, tensorShape, inputLayout_dram); + EXPECT_TRUE(legal); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 137216); + EXPECT_EQ(output_size, 0); + EXPECT_EQ(peak_size, 0); +} + +TEST_F(OpModelTest, SoftmaxSharded) { + const llvm::SmallVector tensorShape = {16 * workerCoresN300 * 32, + 32}; + const auto workerGrid = CreateWorkerGrid(gridShapeHwN300); + const mlir::tt::ttnn::TTNNLayoutAttr inputLayout_l1_hs = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1, + 
mlir::tt::ttnn::TensorMemoryLayout::HeightSharded); + const mlir::tt::ttnn::TTNNLayoutAttr inputLayout_l1_i = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); + + bool legal = false; + std::optional> l1Usage = std::nullopt; + std::optional errorMsg = ""; + size_t cb_size = 0; + size_t peak_size = 0; + size_t output_size = 0; + + std::tie(legal, errorMsg) = Device::getDeviceConstraints(workerGrid); + EXPECT_TRUE(legal); + + std::tie(legal, l1Usage, errorMsg) = SoftmaxOpInterface::getOpConstraints( + tensorShape, inputLayout_l1_hs, -2, tensorShape, inputLayout_l1_hs); + EXPECT_TRUE(legal); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 24576); + EXPECT_EQ(output_size, 32768); + EXPECT_EQ(peak_size, 32768); + + std::tie(legal, l1Usage, errorMsg) = SoftmaxOpInterface::getOpConstraints( + tensorShape, inputLayout_l1_hs, -2, tensorShape, inputLayout_l1_i); + EXPECT_TRUE(legal); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 24576); + EXPECT_EQ(output_size, 32768); + EXPECT_EQ(peak_size, 32768); + + std::tie(legal, l1Usage, errorMsg) = SoftmaxOpInterface::getOpConstraints( + tensorShape, inputLayout_l1_i, -2, tensorShape, inputLayout_l1_hs); + EXPECT_TRUE(legal); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 24576); + EXPECT_EQ(output_size, 32768); + EXPECT_EQ(peak_size, 32768); +} + +TEST_F(OpModelTest, AddInterleaved) { + const llvm::SmallVector tensorShape = {workerCoresN300, 1024}; + const auto workerGrid = CreateWorkerGrid(gridShapeHwN300); + const mlir::tt::ttnn::TTNNLayoutAttr inputLayout_dram = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::DRAM, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); + const mlir::tt::ttnn::TTNNLayoutAttr inputLayout_l1 = + CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved); + + bool legal = false; + std::optional> l1Usage = std::nullopt; + std::optional errorMsg = ""; + size_t cb_size = 0; + size_t peak_size = 0; + size_t output_size = 0; + + std::tie(legal, errorMsg) = Device::getDeviceConstraints(workerGrid); + EXPECT_TRUE(legal); + + std::tie(legal, l1Usage, errorMsg) = AddOpInterface::getOpConstraints( + tensorShape, inputLayout_dram, tensorShape, inputLayout_dram, tensorShape, + inputLayout_dram); + EXPECT_TRUE(legal); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 12288); + EXPECT_EQ(peak_size, 0); + EXPECT_EQ(output_size, 0); + + std::tie(legal, l1Usage, errorMsg) = AddOpInterface::getOpConstraints( + tensorShape, inputLayout_dram, tensorShape, inputLayout_dram, tensorShape, + inputLayout_l1); + EXPECT_TRUE(legal); + EXPECT_TRUE(l1Usage.has_value()); + EXPECT_FALSE(errorMsg.has_value()); + std::tie(cb_size, peak_size, output_size) = l1Usage.value(); + EXPECT_EQ(cb_size, 12288); + EXPECT_EQ(peak_size, 4096); + EXPECT_EQ(output_size, 4096); + + std::tie(legal, l1Usage, errorMsg) = AddOpInterface::getOpConstraints( + tensorShape, inputLayout_dram, tensorShape, inputLayout_l1, tensorShape, + inputLayout_dram); + EXPECT_TRUE(legal); + EXPECT_TRUE(l1Usage.has_value()); + 
+  EXPECT_FALSE(errorMsg.has_value());
+  std::tie(cb_size, peak_size, output_size) = l1Usage.value();
+  EXPECT_EQ(cb_size, 12288);
+  EXPECT_EQ(peak_size, 0);
+  EXPECT_EQ(output_size, 0);
+
+  std::tie(legal, l1Usage, errorMsg) = AddOpInterface::getOpConstraints(
+      tensorShape, inputLayout_dram, tensorShape, inputLayout_l1, tensorShape,
+      inputLayout_l1);
+  EXPECT_TRUE(legal);
+  EXPECT_TRUE(l1Usage.has_value());
+  EXPECT_FALSE(errorMsg.has_value());
+  std::tie(cb_size, peak_size, output_size) = l1Usage.value();
+  EXPECT_EQ(cb_size, 12288);
+  EXPECT_EQ(peak_size, 4096);
+  EXPECT_EQ(output_size, 4096);
+
+  std::tie(legal, l1Usage, errorMsg) = AddOpInterface::getOpConstraints(
+      tensorShape, inputLayout_l1, tensorShape, inputLayout_dram, tensorShape,
+      inputLayout_dram);
+  EXPECT_TRUE(legal);
+  EXPECT_TRUE(l1Usage.has_value());
+  EXPECT_FALSE(errorMsg.has_value());
+  std::tie(cb_size, peak_size, output_size) = l1Usage.value();
+  EXPECT_EQ(cb_size, 12288);
+  EXPECT_EQ(peak_size, 0);
+  EXPECT_EQ(output_size, 0);
+
+  std::tie(legal, l1Usage, errorMsg) = AddOpInterface::getOpConstraints(
+      tensorShape, inputLayout_l1, tensorShape, inputLayout_dram, tensorShape,
+      inputLayout_l1);
+  EXPECT_TRUE(legal);
+  EXPECT_TRUE(l1Usage.has_value());
+  EXPECT_FALSE(errorMsg.has_value());
+  std::tie(cb_size, peak_size, output_size) = l1Usage.value();
+  EXPECT_EQ(cb_size, 12288);
+  EXPECT_EQ(peak_size, 4096);
+  EXPECT_EQ(output_size, 4096);
+
+  std::tie(legal, l1Usage, errorMsg) = AddOpInterface::getOpConstraints(
+      tensorShape, inputLayout_l1, tensorShape, inputLayout_l1, tensorShape,
+      inputLayout_dram);
+  EXPECT_TRUE(legal);
+  EXPECT_TRUE(l1Usage.has_value());
+  EXPECT_FALSE(errorMsg.has_value());
+  std::tie(cb_size, peak_size, output_size) = l1Usage.value();
+  EXPECT_EQ(cb_size, 12288);
+  EXPECT_EQ(peak_size, 0);
+  EXPECT_EQ(output_size, 0);
+
+  std::tie(legal, l1Usage, errorMsg) = AddOpInterface::getOpConstraints(
+      tensorShape, inputLayout_l1, tensorShape, inputLayout_l1, tensorShape,
+      inputLayout_l1);
+  EXPECT_TRUE(legal);
+  EXPECT_TRUE(l1Usage.has_value());
+  EXPECT_FALSE(errorMsg.has_value());
+  std::tie(cb_size, peak_size, output_size) = l1Usage.value();
+  EXPECT_EQ(cb_size, 12288);
+  EXPECT_EQ(peak_size, 4096);
+  EXPECT_EQ(output_size, 4096);
+
+  std::tie(legal, l1Usage, errorMsg) = AddOpInterface::getOpConstraints(
+      tensorShape, inputLayout_dram, tensorShape, inputLayout_dram, tensorShape,
+      inputLayout_dram);
+  EXPECT_TRUE(legal);
+  EXPECT_TRUE(l1Usage.has_value());
+  EXPECT_FALSE(errorMsg.has_value());
+  std::tie(cb_size, peak_size, output_size) = l1Usage.value();
+  EXPECT_EQ(cb_size, 12288);
+  EXPECT_EQ(output_size, 0);
+  EXPECT_EQ(peak_size, 0);
+}
+
+TEST_F(OpModelTest, AddSharded) {
+  const llvm::SmallVector<int64_t> tensorShape = {16 * workerCoresN300 * 32,
+                                                  32};
+  const auto workerGrid = CreateWorkerGrid(gridShapeHwN300);
+  const mlir::tt::ttnn::TTNNLayoutAttr inputLayout_l1_hs =
+      CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1,
+                        mlir::tt::ttnn::TensorMemoryLayout::HeightSharded,
+                        llvm::SmallVector<int64_t>{8, 1});
+  const mlir::tt::ttnn::TTNNLayoutAttr inputLayout_dram =
+      CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::DRAM,
+                        mlir::tt::ttnn::TensorMemoryLayout::Interleaved);
+
+  bool legal = false;
+  std::optional<std::tuple<size_t, size_t, size_t>> l1Usage = std::nullopt;
+  std::optional<std::string> errorMsg = "";
+  size_t cb_size = 0;
+  size_t peak_size = 0;
+  size_t output_size = 0;
+
+  std::tie(legal, errorMsg) = Device::getDeviceConstraints(workerGrid);
+  EXPECT_TRUE(legal);
+
+  std::tie(legal, l1Usage, errorMsg) =
+      AddOpInterface::getOpConstraints(
+          tensorShape, inputLayout_l1_hs, tensorShape, inputLayout_dram,
+          tensorShape, inputLayout_l1_hs);
+  EXPECT_TRUE(legal);
+  EXPECT_TRUE(l1Usage.has_value());
+  EXPECT_FALSE(errorMsg.has_value());
+  std::tie(cb_size, peak_size, output_size) = l1Usage.value();
+  EXPECT_EQ(cb_size, 32768);
+  EXPECT_EQ(peak_size, 229376);
+  EXPECT_EQ(output_size, 229376);
+
+  std::tie(legal, l1Usage, errorMsg) = AddOpInterface::getOpConstraints(
+      tensorShape, inputLayout_l1_hs, tensorShape, inputLayout_dram,
+      tensorShape, inputLayout_dram);
+  EXPECT_TRUE(legal);
+  EXPECT_TRUE(l1Usage.has_value());
+  EXPECT_FALSE(errorMsg.has_value());
+  std::tie(cb_size, peak_size, output_size) = l1Usage.value();
+  EXPECT_EQ(cb_size, 65536);
+  EXPECT_EQ(peak_size, 0);
+  EXPECT_EQ(output_size, 0);
+
+  std::tie(legal, l1Usage, errorMsg) = AddOpInterface::getOpConstraints(
+      tensorShape, inputLayout_dram, tensorShape, inputLayout_dram, tensorShape,
+      inputLayout_l1_hs);
+  EXPECT_TRUE(legal);
+  EXPECT_TRUE(l1Usage.has_value());
+  EXPECT_FALSE(errorMsg.has_value());
+  std::tie(cb_size, peak_size, output_size) = l1Usage.value();
+  EXPECT_EQ(cb_size, 65536);
+  EXPECT_EQ(peak_size, 229376);
+  EXPECT_EQ(output_size, 229376);
+}
+
+class OpModelMatmulParam
+    : public OpModelTest,
+      public testing::WithParamInterface<
+          std::tuple<llvm::SmallVector<int64_t>,         // input shape A
+                     mlir::tt::ttnn::TensorMemoryLayout, // input layout A
+                     mlir::tt::ttnn::BufferType,         // input buffer type A
+                     llvm::SmallVector<int64_t>,         // input virtual grid A
+                     llvm::SmallVector<int64_t>,         // input shape B
+                     mlir::tt::ttnn::TensorMemoryLayout, // input layout B
+                     mlir::tt::ttnn::BufferType,         // input buffer type B
+                     llvm::SmallVector<int64_t>,         // input virtual grid B
+                     llvm::SmallVector<int64_t>,         // output shape
+                     mlir::tt::ttnn::TensorMemoryLayout, // output layout
+                     mlir::tt::ttnn::BufferType,         // output buffer type
+                     llvm::SmallVector<int64_t>,         // output virtual grid
+                     llvm::SmallVector<int64_t>,         // physical grid
+                     bool,                               // expected valid
+                     size_t,                             // expected cb size
+                     size_t,                             // expected peak size
+                     size_t                              // expected output size
+                     >> {};
+
+TEST_P(OpModelMatmulParam, MatmulParam) {
+
+  auto params = GetParam();
+  llvm::SmallVector<int64_t> inputShapeA = std::get<0>(params);
+  mlir::tt::ttnn::TensorMemoryLayout inputTensorLayoutA = std::get<1>(params);
+  mlir::tt::ttnn::BufferType inputBufferTypeA = std::get<2>(params);
+  llvm::SmallVector<int64_t> inputVirtualGridA = std::get<3>(params);
+  llvm::SmallVector<int64_t> inputShapeB = std::get<4>(params);
+  mlir::tt::ttnn::TensorMemoryLayout inputTensorLayoutB = std::get<5>(params);
+  mlir::tt::ttnn::BufferType inputBufferTypeB = std::get<6>(params);
+  llvm::SmallVector<int64_t> inputVirtualGridB = std::get<7>(params);
+  llvm::SmallVector<int64_t> outputShape = std::get<8>(params);
+  mlir::tt::ttnn::TensorMemoryLayout outputTensorLayout = std::get<9>(params);
+  mlir::tt::ttnn::BufferType outputBufferType = std::get<10>(params);
+  llvm::SmallVector<int64_t> outputVirtualGrid = std::get<11>(params);
+  llvm::SmallVector<int64_t> physicalGrid = std::get<12>(params);
+  bool expectedLegal = std::get<13>(params);
+  size_t expectedCbSize = std::get<14>(params);
+  size_t expectedPeakSize = std::get<15>(params);
+  size_t expectedOutputSize = std::get<16>(params);
+
+  const mlir::tt::ttnn::TTNNLayoutAttr inputLayoutA = CreateTiledLayout(
+      inputShapeA, inputBufferTypeA, inputTensorLayoutA, inputVirtualGridA);
+  const mlir::tt::ttnn::TTNNLayoutAttr inputLayoutB = CreateTiledLayout(
+      inputShapeB, inputBufferTypeB, inputTensorLayoutB, inputVirtualGridB);
+  const mlir::tt::ttnn::TTNNLayoutAttr outputLayout = CreateTiledLayout(
outputShape, outputBufferType, outputTensorLayout, outputVirtualGrid); + + bool legal = false; + std::optional> l1Usage = std::nullopt; + std::optional errorMsg = ""; + size_t cbSize = 0; + size_t peakSize = 0; + size_t outputSize = 0; + + std::tie(legal, l1Usage, errorMsg) = MatmulOpInterface::getOpConstraints( + inputShapeA, inputLayoutA, inputShapeB, inputLayoutB, outputShape, + outputLayout, false, false); + EXPECT_EQ(legal, expectedLegal); + EXPECT_EQ(l1Usage.has_value(), expectedLegal); + EXPECT_EQ(errorMsg.has_value(), !expectedLegal); + + if (l1Usage.has_value()) { + std::tie(cbSize, peakSize, outputSize) = l1Usage.value(); + EXPECT_EQ(cbSize, expectedCbSize); + EXPECT_EQ(peakSize, expectedPeakSize); + EXPECT_EQ(outputSize, expectedOutputSize); + } + + std::cout << errorMsg.value_or("No errors") << std::endl; +} + +INSTANTIATE_TEST_SUITE_P( + MatmulInterleavedTests, OpModelMatmulParam, + ::testing::Values( + std::make_tuple( + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{8, 8}, + llvm::SmallVector{8, 8}, true, 786432, 0, 0), + std::make_tuple( + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{8, 8}, + llvm::SmallVector{8, 8}, true, 786432, 151552, 151552), + std::make_tuple( + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{8, 8}, + llvm::SmallVector{8, 8}, true, 786432, 0, 0), + std::make_tuple( + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{8, 8}, + llvm::SmallVector{8, 8}, true, 786432, 151552, 151552), + std::make_tuple( + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{8, 8}, + llvm::SmallVector{8, 8}, true, 786432, 0, 0), + std::make_tuple( + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + 
mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{8, 8}, + llvm::SmallVector{8, 8}, true, 786432, 151552, 151552), + std::make_tuple( + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{8, 8}, + llvm::SmallVector{8, 8}, true, 786432, 0, 0), + std::make_tuple( + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{8, 8}, + llvm::SmallVector{2048, 2048}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{8, 8}, + llvm::SmallVector{8, 8}, true, 786432, 151552, 151552))); + +INSTANTIATE_TEST_SUITE_P( + MatmulShardedTests, OpModelMatmulParam, + ::testing::Values( + std::make_tuple( + llvm::SmallVector{56 * 32, 56 * 32}, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{7, 8}, + llvm::SmallVector{56 * 32, 56 * 32}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{7, 8}, + llvm::SmallVector{56 * 32, 56 * 32}, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{7, 8}, + llvm::SmallVector{7, 8}, true, 430144, 114688, 114688), + std::make_tuple( + llvm::SmallVector{56 * 32, 56 * 32}, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{7, 8}, + llvm::SmallVector{56 * 32, 56 * 32}, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{7, 8}, + llvm::SmallVector{56 * 32, 56 * 32}, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{7, 8}, + llvm::SmallVector{7, 8}, false, -1, -1, -1), + std::make_tuple( + llvm::SmallVector{56 * 32, 56 * 32}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{7, 8}, + llvm::SmallVector{56 * 32, 56 * 32}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{7, 8}, + llvm::SmallVector{56 * 32, 56 * 32}, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{7, 8}, + llvm::SmallVector{7, 8}, true, 262144, 401408, + 401408), // matmul bug shards to less cores + std::make_tuple( + llvm::SmallVector{56 * 32, 56 * 32}, + mlir::tt::ttnn::TensorMemoryLayout::BlockSharded, + mlir::tt::ttnn::BufferType::L1, llvm::SmallVector{7, 8}, + llvm::SmallVector{56 * 32, 56 * 32}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector{7, 8}, + llvm::SmallVector{56 * 32, 56 * 32}, + mlir::tt::ttnn::TensorMemoryLayout::Interleaved, + 
mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector<int64_t>{7, 8},
+            llvm::SmallVector<int64_t>{7, 8}, true, 544832, 0, 0),
+        std::make_tuple(
+            llvm::SmallVector<int64_t>{56 * 32, 56 * 32},
+            mlir::tt::ttnn::TensorMemoryLayout::BlockSharded,
+            mlir::tt::ttnn::BufferType::L1, llvm::SmallVector<int64_t>{7, 8},
+            llvm::SmallVector<int64_t>{56 * 32, 56 * 32},
+            mlir::tt::ttnn::TensorMemoryLayout::HeightSharded,
+            mlir::tt::ttnn::BufferType::L1, llvm::SmallVector<int64_t>{56, 1},
+            llvm::SmallVector<int64_t>{56 * 32, 56 * 32},
+            mlir::tt::ttnn::TensorMemoryLayout::Interleaved,
+            mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector<int64_t>{7, 8},
+            llvm::SmallVector<int64_t>{7, 8}, false, -1, -1, -1),
+        std::make_tuple(
+            llvm::SmallVector<int64_t>{1 * 32, 56 * 32},
+            mlir::tt::ttnn::TensorMemoryLayout::WidthSharded,
+            mlir::tt::ttnn::BufferType::L1, llvm::SmallVector<int64_t>{1, 56},
+            llvm::SmallVector<int64_t>{56 * 32, 56 * 32},
+            mlir::tt::ttnn::TensorMemoryLayout::Interleaved,
+            mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector<int64_t>{7, 8},
+            llvm::SmallVector<int64_t>{1 * 32, 56 * 32},
+            mlir::tt::ttnn::TensorMemoryLayout::WidthSharded,
+            mlir::tt::ttnn::BufferType::L1, llvm::SmallVector<int64_t>{1, 56},
+            llvm::SmallVector<int64_t>{7, 8}, true, 8256, 2048, 2048),
+        std::make_tuple(
+            llvm::SmallVector<int64_t>{56 * 32, 1 * 32},
+            mlir::tt::ttnn::TensorMemoryLayout::HeightSharded,
+            mlir::tt::ttnn::BufferType::L1, llvm::SmallVector<int64_t>{56, 1},
+            llvm::SmallVector<int64_t>{1 * 32, 56 * 32},
+            mlir::tt::ttnn::TensorMemoryLayout::Interleaved,
+            mlir::tt::ttnn::BufferType::DRAM, llvm::SmallVector<int64_t>{7, 8},
+            llvm::SmallVector<int64_t>{56 * 32, 56 * 32},
+            mlir::tt::ttnn::TensorMemoryLayout::HeightSharded,
+            mlir::tt::ttnn::BufferType::L1, llvm::SmallVector<int64_t>{56, 1},
+            llvm::SmallVector<int64_t>{7, 8}, true, 114688, 114688, 114688)));
+} // namespace mlir::tt::op_model::ttnn
diff --git a/test/unittests/OpModel/TTNN/Op/CMakeLists.txt b/test/unittests/OpModel/TTNN/Op/CMakeLists.txt
new file mode 100644
index 000000000..3a2c89530
--- /dev/null
+++ b/test/unittests/OpModel/TTNN/Op/CMakeLists.txt
@@ -0,0 +1,24 @@
+if (TTMLIR_ENABLE_OPMODEL)
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+add_executable(TestOpModelInterface
+TestOpModelInterface.cpp
+)
+
+target_include_directories(TestOpModelInterface
+  PUBLIC
+  ${PROJECT_SOURCE_DIR}/lib/OpModel/TTNN/
+  ${PROJECT_SOURCE_DIR}/test/unittests/OpModel/TTNN/
+)
+
+target_link_libraries(TestOpModelInterface
+  PRIVATE
+  gtest
+  gtest_main
+  TTNNOpModelLib
+  MLIRTTDialect
+  MLIRTTIRDialect
+  MLIRTTNNDialect
+)
+endif()
diff --git a/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp b/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp
new file mode 100644
index 000000000..f40b7351b
--- /dev/null
+++ b/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp
@@ -0,0 +1,206 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "OpModelFixture.h"
+
+#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
+
+#include "mlir/IR/AffineExpr.h"
+#include "gtest/gtest.h"
+
+#include <optional>
+
+namespace mlir::tt::ttnn {
+
+class OpModelBase : public OpModelFixture {
+public:
+  // helper function to extract op data and call into get op constraints
+  std::optional<
+      std::tuple<bool, std::optional<std::tuple<size_t, size_t, size_t>>,
+                 std::optional<std::string>>>
+  getOpConstraints(Operation *op) {
+    std::vector<TTNNLayoutAttr> inputs;
+
+    // TODO(odjuricic): check for DPS explicitly.
+    // create input layouts
+    auto numOperand = op->getNumOperands();
+    // some ops have multiple operands
+    auto limit = (numOperand > 1) ?
numOperand - 1 : numOperand;
+    for (size_t i = 0; i < limit; i++) {
+      auto operand = op->getOperand(i);
+      auto inputShape =
+          mlir::cast<RankedTensorType>(operand.getType()).getShape();
+      auto inputLayout = CreateTiledLayout(inputShape, BufferType::L1,
+                                           TensorMemoryLayout::Interleaved);
+      inputs.push_back(inputLayout);
+    }
+
+    // create output layout
+    auto output = op->getResult(0);
+    auto outputShape =
+        mlir::cast<RankedTensorType>(output.getType()).getShape();
+    auto outputLayout = CreateTiledLayout(outputShape, BufferType::L1,
+                                          TensorMemoryLayout::Interleaved);
+
+    // call op model interface - getOpConstraints()
+    if (OpModel backend = dyn_cast<OpModel>(op)) {
+      auto constraints = backend.getOpConstraints(inputs, outputLayout);
+      return constraints;
+    }
+    return std::nullopt;
+  }
+
+  mlir::tt::DeviceAttr getFakeDeviceAttr() {
+    auto deviceIdx = mlir::getAffineConstantExpr(0, &context);
+    auto shardOffset = mlir::getAffineConstantExpr(0, &context);
+    auto d0 = mlir::getAffineDimExpr(0, &context); // d0
+    auto d1 = mlir::getAffineDimExpr(1, &context); // d1
+    auto map3 = mlir::AffineMap::get(
+        /*dimCount=*/2, /*symbolCount=*/0, {deviceIdx, d0, d1}, &context);
+    auto map4 = mlir::AffineMap::get(
+        /*dimCount=*/2, /*symbolCount=*/0, {deviceIdx, d0, d1, shardOffset},
+        &context);
+    auto workerGrid = GridAttr::get(&context, gridShapeHwN300, map3);
+
+    return DeviceAttr::get(&context, workerGrid, map4, map4, {1}, {0});
+  }
+
+  mlir::Value createEmptyTensor(llvm::ArrayRef<int64_t> tensorShape) {
+    Type elementType = builder.getBF16Type();
+    RankedTensorType rankedTensorType =
+        RankedTensorType::get(tensorShape, elementType);
+    return builder.create<EmptyOp>(builder.getUnknownLoc(), rankedTensorType,
+                                   ShapeAttr::get(&context, tensorShape),
+                                   nullptr, nullptr, nullptr, nullptr);
+  }
+};
+
+TEST_F(OpModelBase, ReluInterface) {
+  // create ReluOp
+  llvm::SmallVector<int64_t> tensorShape = {workerCoresN300, 1024};
+
+  auto input = createEmptyTensor(tensorShape);
+  auto output = createEmptyTensor(tensorShape);
+
+  auto relu = builder.create<ReluOp>(builder.getUnknownLoc(), output.getType(),
+                                     ::mlir::ValueRange{input, output});
+  relu->setAttr(DeviceAttr::name, getFakeDeviceAttr());
+
+  // test ReluOp interface
+  auto value = getOpConstraints(relu.getOperation());
+  if (value.has_value()) {
+    auto constraints = value.value();
+    EXPECT_EQ(std::get<bool>(constraints), true);
+    auto l1 = std::get<1>(constraints);
+    if (l1.has_value()) {
+      const auto &[cb_size, peak_size, output_size] = l1.value();
+      EXPECT_EQ(cb_size, 8192);
+      EXPECT_EQ(peak_size, 4096);
+      EXPECT_EQ(output_size, 4096);
+    } else {
+      FAIL() << "Missing L1 constraints; Error="
+             << std::get<2>(constraints).value() << std::endl;
+    }
+  } else {
+    FAIL() << "Failed to cast ReluOp to OpModel";
+  }
+}
+TEST_F(OpModelBase, SoftmaxInterface) {
+  // create SoftmaxOp
+  llvm::SmallVector<int64_t> tensorShape = {workerCoresN300, 1024};
+
+  auto input = createEmptyTensor(tensorShape);
+  auto output = createEmptyTensor(tensorShape);
+
+  auto softmax = builder.create<SoftmaxOp>(builder.getUnknownLoc(),
+                                           output.getType(), input, -1);
+  softmax->setAttr(DeviceAttr::name, getFakeDeviceAttr());
+
+  // test SoftmaxOp interface
+  auto value = getOpConstraints(softmax.getOperation());
+  if (value.has_value()) {
+    auto constraints = value.value();
+    EXPECT_EQ(std::get<bool>(constraints), true);
+    auto l1 = std::get<1>(constraints);
+    if (l1.has_value()) {
+      const auto &[cb_size, peak_size, output_size] = l1.value();
+      EXPECT_EQ(cb_size, 137216);
+      EXPECT_EQ(peak_size, 4096);
+      EXPECT_EQ(output_size, 4096);
+    } else {
+      FAIL() << "Missing L1 constraints";
+    }
+  } else {
+    FAIL() << "Failed to cast SoftmaxOp to OpModel";
+  }
+}
+
+TEST_F(OpModelBase, AddInterface) {
+  // create AddOp
+  llvm::SmallVector<int64_t> tensorShape = {workerCoresN300, 1024};
+
+  auto input1 = createEmptyTensor(tensorShape);
+  auto input2 = createEmptyTensor(tensorShape);
+  auto output = createEmptyTensor(tensorShape);
+
+  auto add = builder.create<AddOp>(builder.getUnknownLoc(), output.getType(),
+                                   ::mlir::ValueRange{input1, input2, output});
+  add->setAttr(DeviceAttr::name, getFakeDeviceAttr());
+
+  // test AddOp interface
+  auto value = getOpConstraints(add.getOperation());
+  if (value.has_value()) {
+    auto constraints = value.value();
+    EXPECT_EQ(std::get<bool>(constraints), true);
+    auto l1 = std::get<1>(constraints);
+    if (l1.has_value()) {
+      const auto &[cb_size, peak_size, output_size] = l1.value();
+      EXPECT_EQ(cb_size, 12288);
+      EXPECT_EQ(peak_size, 4096);
+      EXPECT_EQ(output_size, 4096);
+    } else {
+      FAIL() << "Missing L1 constraints";
+    }
+  } else {
+    FAIL() << "Failed to cast AddOp to OpModel";
+  }
+}
+
+TEST_F(OpModelBase, MatmulInterface) {
+  // create MatmulOp
+  llvm::SmallVector<int64_t> tensorShapeA = {2048, 1024};
+  llvm::SmallVector<int64_t> tensorShapeB = {1024, 2048};
+  llvm::SmallVector<int64_t> tensorShapeO = {2048, 2048};
+
+  auto inputA = createEmptyTensor(tensorShapeA);
+  auto inputB = createEmptyTensor(tensorShapeB);
+  auto output = createEmptyTensor(tensorShapeO);
+
+  auto matmul =
+      builder.create<MatmulOp>(builder.getUnknownLoc(), output.getType(),
+                               ::mlir::ValueRange{inputA, inputB, output});
+  matmul->setAttr(DeviceAttr::name, getFakeDeviceAttr());
+
+  // test MatmulOp interface
+  auto value = getOpConstraints(matmul.getOperation());
+  if (value.has_value()) {
+    auto constraints = value.value();
+    EXPECT_EQ(std::get<bool>(constraints), true);
+    auto l1 = std::get<1>(constraints);
+    if (l1.has_value()) {
+      const auto &[cb_size, peak_size, output_size] = l1.value();
+      EXPECT_EQ(cb_size, 786432);
+      EXPECT_EQ(peak_size, 151552);
+      EXPECT_EQ(output_size, 151552);
+    } else {
+      FAIL() << "Missing L1 constraints";
+    }
+  } else {
+    FAIL() << "Failed to cast MatmulOp to OpModel";
+  }
+}
+
+} // namespace mlir::tt::ttnn
diff --git a/test/unittests/OpModel/TTNN/OpModelFixture.h b/test/unittests/OpModel/TTNN/OpModelFixture.h
new file mode 100644
index 000000000..a39967335
--- /dev/null
+++ b/test/unittests/OpModel/TTNN/OpModelFixture.h
@@ -0,0 +1,128 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#ifndef UNITTESTS_OPMODEL_TTNN_OPMODELFIXTURE_H
+#define UNITTESTS_OPMODEL_TTNN_OPMODELFIXTURE_H
+
+#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
+#include "ttmlir/Dialect/TTNN/Utils/VirtualToPhysicalAffineMap.h"
+#include "ttmlir/Utils.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "gtest/gtest.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <optional>
+
+class OpModelFixture : public ::testing::Test {
+public:
+  mlir::MLIRContext context;
+  mlir::OpBuilder builder = mlir::OpBuilder(&context);
+
+  void SetUp() override {
+    context.loadDialect<mlir::tt::ttnn::TTNNDialect>();
+  }
+
+  // helper function
+  llvm::SmallVector<int64_t>
+  GetTensorShapeInTiles(const llvm::ArrayRef<int64_t> &tensorShape) {
+    return {ttmlir::utils::alignUp(tensorShape[0], 32L),
+            ttmlir::utils::alignUp(tensorShape[1], 32L)};
+  }
+
+  static llvm::SmallVector<int64_t> GetPhysicalGridSize() {
+    llvm::SmallVector<int64_t> grid = {gridShapeHwN300[0], gridShapeHwN300[1]};
+    return grid;
+  }
+
+  // helper function to get virtual grid shape based on tensor shape and
+  // memory
+  llvm::SmallVector<int64_t> GetVirtualGridShape(
+      const llvm::ArrayRef<int64_t> &tensorShape,
+      const mlir::tt::ttnn::TensorMemoryLayout &tensorMemoryLayout,
+      const llvm::ArrayRef<int64_t> &gridPhyCores = GetPhysicalGridSize()) {
+
+    llvm::SmallVector<int64_t> tensorShapeTiles =
+        GetTensorShapeInTiles(tensorShape);
+
+    assert(tensorShape.size() == 2);
+    switch (tensorMemoryLayout) {
+    case mlir::tt::ttnn::TensorMemoryLayout::WidthSharded:
+      return {1, std::min(tensorShapeTiles[0] * tensorShapeTiles[1],
+                          gridPhyCores[0] * gridPhyCores[1])};
+    case mlir::tt::ttnn::TensorMemoryLayout::HeightSharded:
+      return {std::min(tensorShapeTiles[0] * tensorShapeTiles[1],
+                       gridPhyCores[0] * gridPhyCores[1]),
+              1};
+    case mlir::tt::ttnn::TensorMemoryLayout::BlockSharded:
+      return {std::min(gridPhyCores[0], tensorShapeTiles[0]),
+              std::min(gridPhyCores[1], tensorShapeTiles[1])};
+    default:
+      return {gridPhyCores[0], gridPhyCores[1]};
+    }
+  }
+
+  mlir::tt::ttnn::TTNNLayoutAttr CreateTiledLayout(
+      const llvm::ArrayRef<int64_t> &tensorShape,
+      const mlir::tt::ttnn::BufferType &bufferType,
+      const mlir::tt::ttnn::TensorMemoryLayout &tensorMemoryLayout,
+      const std::optional<llvm::SmallVector<int64_t>> &virtualGrid =
+          std::nullopt,
+      const llvm::SmallVector<int64_t> physicalGrid = GetPhysicalGridSize()) {
+    const auto &virtualGridSelected =
+        virtualGrid.has_value()
+            ? virtualGrid.value()
+            : GetVirtualGridShape(tensorShape, tensorMemoryLayout);
+
+    return mlir::tt::ttnn::TTNNLayoutAttr::get(
+        &context, tensorShape,
+        mlir::tt::TileType::get(&context, builder.getBF16Type()), bufferType,
+        CreateGrid(&context, tensorMemoryLayout, virtualGridSelected,
+                   physicalGrid),
+        mlir::tt::ttnn::TensorMemoryLayoutAttr::get(&context,
+                                                    tensorMemoryLayout));
+  }
+
+  mlir::tt::ttnn::TTNNLayoutAttr CreateRowMajorLayout(
+      const llvm::ArrayRef<int64_t> &tensorShape,
+      const mlir::tt::ttnn::BufferType &bufferType,
+      const mlir::tt::ttnn::TensorMemoryLayout &tensorMemoryLayout,
+      const llvm::ArrayRef<int64_t> &gridShape = GetPhysicalGridSize()) {
+    return mlir::tt::ttnn::TTNNLayoutAttr::get(
+        &context, tensorShape, builder.getBF16Type(), bufferType,
+        CreateGrid(&context, tensorMemoryLayout,
+                   GetVirtualGridShape(tensorShape, tensorMemoryLayout),
+                   GetPhysicalGridSize()),
+        mlir::tt::ttnn::TensorMemoryLayoutAttr::get(&context,
+                                                    tensorMemoryLayout));
+  }
+
+  mlir::tt::GridAttr
+  CreateGrid(::mlir::MLIRContext *context,
+             const mlir::tt::ttnn::TensorMemoryLayout tensorMemoryLayout,
+             const llvm::ArrayRef<int64_t> virtualGridSize,
+             const llvm::ArrayRef<int64_t> physicalGridSize) {
+
+    auto affineMap =
+        mlir::tt::ttnn::utils::CreateSingleDeviceVirtualToPhysicalAffineMap(
+            context, tensorMemoryLayout, physicalGridSize);
+
+    return mlir::tt::GridAttr::get(context, virtualGridSize, affineMap);
+  }
+
+  mlir::tt::GridAttr CreateWorkerGrid(
+      const llvm::ArrayRef<int64_t> physicalGridSize = GetPhysicalGridSize()) {
+    return mlir::tt::GridAttr::get(&context, physicalGridSize);
+  }
+
+  static constexpr std::array<int64_t, 2> gridShapeHwN300 = {7, 8};
+  static constexpr size_t workerCoresN300 = 56;
+};
+
+#endif // UNITTESTS_OPMODEL_TTNN_OPMODELFIXTURE_H
diff --git a/test/unittests/lit.cfg.py b/test/unittests/lit.cfg.py
index 525d2013d..31ca168ba 100644
--- a/test/unittests/lit.cfg.py
+++ b/test/unittests/lit.cfg.py
@@ -39,3 +39,22 @@
 # that causes the tests to fail.
if "HOME" in os.environ: config.environment["HOME"] = os.environ["HOME"] + + +if "TT_MLIR_HOME" in os.environ: + print(f"{os.environ['TT_MLIR_HOME']}") + config.environment["TT_MLIR_HOME"] = os.environ["TT_MLIR_HOME"] +else: + raise OSError("TT_MLIR_HOME environment variable is not set") + +if "TT_METAL_HOME" in os.environ: + print(f"{os.environ['TT_METAL_HOME']}") + config.environment["TT_METAL_HOME"] = os.environ["TT_METAL_HOME"] +else: + raise OSError("TT_METAL_HOME environment variable is not set") + +if "ARCH_NAME" in os.environ: + print(f"ARCH_NAME={os.environ['ARCH_NAME']}") + config.environment["ARCH_NAME"] = os.environ["ARCH_NAME"] +else: + raise OSError("ARCH_NAME environment variable is not set")