From c6da33c26676f4c82d3c07981f3d90d8f9b4e99f Mon Sep 17 00:00:00 2001 From: svuckovic Date: Mon, 23 Dec 2024 16:03:56 +0000 Subject: [PATCH] wip --- .github/workflows/build-and-test.yml | 16 +- runtime/include/tt/runtime/detail/ttnn.h | 5 + runtime/include/tt/runtime/runtime.h | 7 + runtime/lib/runtime.cpp | 34 ++++ runtime/lib/ttnn/runtime.cpp | 168 +++++++++++++++++- runtime/tools/python/ttrt/common/run.py | 41 +++++ runtime/tools/python/ttrt/common/util.py | 8 +- runtime/tools/python/ttrt/runtime/__init__.py | 3 + runtime/tools/python/ttrt/runtime/module.cpp | 9 + test/lit.cfg.py | 37 ++++ test/ttmlir/Silicon/TTNN/emitc/mnist.mlir | 9 +- .../ttmlir/Silicon/TTNN/emitc/simple_add.mlir | 5 +- test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir | 5 +- tools/ttnn-standalone/CMakeLists.txt | 7 + tools/ttnn-standalone/compile_dylib.py | 112 ++++++++++++ tools/ttnn-standalone/ttnn-dylib.cpp | 62 +++---- tools/ttnn-standalone/ttnn-dylib.hpp | 3 +- 17 files changed, 487 insertions(+), 44 deletions(-) create mode 100755 tools/ttnn-standalone/compile_dylib.py diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 68db5d1cf..a4f884223 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -265,10 +265,11 @@ jobs: fail-fast: false matrix: build: [ - {runs-on: n150, enable_perf: OFF, name: "run", ttrt_flags: "--non-zero"}, - {runs-on: n150, enable_perf: ON, name: "perf"}, - {runs-on: n300, enable_perf: OFF, name: "run", ttrt_flags: "--non-zero"}, - {runs-on: n300, enable_perf: ON, name: "perf"}, + {runs-on: n150, enable_perf: OFF, emitc: OFF, name: "run", ttrt_flags: "--non-zero"}, + {runs-on: n150, enable_perf: ON, emitc: OFF, name: "perf"}, + {runs-on: n150, enable_perf: OFF, emitc: ON, name: "run", ttrt_flags: "--emitc"}, + {runs-on: n300, enable_perf: OFF, emitc: OFF, name: "run", ttrt_flags: "--non-zero"}, + {runs-on: n300, enable_perf: ON, emitc: OFF, name: "perf"}, ] name: "run-tests (${{ matrix.build.runs-on }}, ${{ matrix.build.enable_perf }}, ${{ matrix.build.name }})" @@ -374,6 +375,13 @@ jobs: ttrt ${{ matrix.build.name }} ${{ matrix.build.ttrt_flags }} ${{ steps.strings.outputs.build-output-dir }}/test/ttmlir/Silicon/TTNN/perf_unit cp ttrt_report.xml ${{ steps.strings.outputs.test_report_path }} + - name: Run emitc tests + shell: bash + if: matrix.build.emitc == 'ON' + run: | + source env/activate + ttrt ${{ matrix.build.name }} ${{ matrix.build.ttrt_flags }} ${{ steps.strings.outputs.build-output-dir }}/test/ttmlir/Silicon/TTNN/emitc + - name: Upload ttrt test report json if: always() uses: actions/upload-artifact@v4 diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index 2310789b6..bd07a05b9 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -130,6 +130,11 @@ std::vector runProgram(::ttnn::MeshDevice &meshDevice, std::uint32_t programIndex, std::vector<::ttnn::Tensor *> const &inputs); +bool compareOuts(std::vector &lhs, std::vector &rhs); + +std::vector do_stuff(void *so, std::string func_name, + std::vector inputs, Device device); + } // namespace tt::runtime::ttnn #endif diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index 2f278ffc1..0687041db 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -111,6 +111,13 @@ Event submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputs, std::vector const &outputs); +void *openSo(std::string path); + +std::vector runSoProgram(void *so, std::string name, + std::vector inputs, Device device); + +bool compareOuts(std::vector &lhs, std::vector &rhs); + } // namespace tt::runtime #endif diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index c25cfed51..20227ce4b 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -11,6 +11,8 @@ #if defined(TT_RUNTIME_ENABLE_TTNN) #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/types.h" + +#include #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) @@ -486,4 +488,36 @@ Event submit(Device deviceHandle, Binary executableHandle, #endif LOG_FATAL("runtime is not enabled"); } + +void *openSo(std::string path) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + void *handle = dlopen(path.c_str(), RTLD_LAZY); + if (!handle) { + std::cerr << "Failed to load shared object: " << dlerror() << std::endl; + throw std::runtime_error("Failed to load shared object"); + } + + dlerror(); + return handle; + } +#endif + throw std::runtime_error("ttnn runtime is not enabled"); +} + +std::vector runSoProgram(void *so, std::string name, + std::vector inputs, Device device) { + + return ::tt::runtime::ttnn::do_stuff(so, name, inputs, device); +} + +bool compareOuts(std::vector &lhs, std::vector &rhs) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::compareOuts(lhs, rhs); + } +#endif + throw std::runtime_error("ttnn runtime is not enabled"); +} + } // namespace tt::runtime diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 3fd7ba1b9..eeee8270d 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -1,6 +1,9 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 + +#include "tt/runtime/runtime.h" + #include "tt/runtime/detail/debug.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" @@ -12,6 +15,10 @@ #include "ttnn/tensor/shape/small_vector.hpp" #include "ttnn/tensor/types.hpp" +#include +#include +#include + namespace tt::runtime::ttnn { using ::tt::runtime::DeviceRuntime; @@ -154,7 +161,7 @@ Tensor createTensor(Device device, Layout layout, createOwnedTensor(nullptr, shape, stride, itemsize, utils::fromTTNNDataType(layoutDesc.dataType)); Tensor out = utils::createRuntimeTensorFromTTNN(tensor); - return toLayout(out, device, layout); + return ::tt::runtime::ttnn::toLayout(out, device, layout); } DeviceVariant targetDevice = getTargetDevice(device.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN)); @@ -527,4 +534,163 @@ std::vector submit(Device deviceHandle, Binary executableHandle, return outputs; } +std::vector do_stuff(void *so, std::string func_name, + std::vector inputs, Device device) { + + ::ttnn::MeshDevice &ttnnMeshDevice = + device.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); + + assert(ttnnMeshDevice.get_devices().size() == 1); + + ::ttnn::Device *ttnnDevice = ttnnMeshDevice.get_devices()[0]; + + // Convert inputs to TTNN tensors using .as method + // + std::vector<::ttnn::Tensor> ttnnInputs; + for (auto &input : inputs) { + LOG_ASSERT(input.matchesRuntime(DeviceRuntime::TTNN)); + ttnnInputs.push_back(input.as<::ttnn::Tensor>(DeviceRuntime::TTNN)); + } + + // Clear previous error + // + dlerror(); + + // Get function from shared object + // + using ForwardFunction = std::vector<::ttnn::Tensor> (*)( + std::vector<::ttnn::Tensor>, ::ttnn::Device *); + std::cout << "before" << std::endl; + ForwardFunction forwardFunc = (ForwardFunction)dlsym(so, func_name.c_str()); + std::cout << "after" << std::endl; + + const char *dlsym_error = dlerror(); + if (dlsym_error) { + std::cerr << "Failed to load symbol: " << dlsym_error << std::endl; + dlclose(so); + throw std::runtime_error("Failed to load symbol"); + } + + // Call function + // + std::vector<::ttnn::Tensor> ttnnOutputs = forwardFunc(ttnnInputs, ttnnDevice); + + // Convert outputs to Tensor using Tensor constructor + // + std::vector outputs; + for (::ttnn::Tensor &output : ttnnOutputs) { + // using Storage = std::variant; + if (std::holds_alternative( + output.tensor_attributes->storage)) { + std::cout << "OwnedStorage" << std::endl; + } else if (std::holds_alternative( + output.tensor_attributes->storage)) { + std::cout << "DeviceStorage" << std::endl; + } else if (std::holds_alternative( + output.tensor_attributes->storage)) { + std::cout << "BorrowedStorage" << std::endl; + } else if (std::holds_alternative( + output.tensor_attributes->storage)) { + std::cout << "MultiDeviceHostStorage" << std::endl; + } else if (std::holds_alternative( + output.tensor_attributes->storage)) { + std::cout << "MultiDeviceStorage" << std::endl; + } else { + std::cout << "Unknown" << std::endl; + } + + // BorrowedBuffer borrowedBuffer = + // std::get(output.tensor_attributes->storage).buffer; + // std::visit( + // [&outputs, &output](auto &&buffer) { + // outputs.push_back( + // Tensor(std::make_shared<::ttnn::Tensor>(std::move(output)), + // std::shared_ptr(static_cast(buffer.data()), + // [](void *) {}), + // DeviceRuntime::TTNN)); + // }, + // borrowedBuffer); + + OwnedStorage ownedStorage = + std::get(output.tensor_attributes->storage).buffer; + + std::visit( + [&outputs, &output](auto &&buffer) { + outputs.push_back( + Tensor(std::make_shared<::ttnn::Tensor>(std::move(output)), + std::shared_ptr(static_cast(buffer.data()), + [](void *) {}), + DeviceRuntime::TTNN)); + }, + ownedStorage.get_buffer()); + } + + return outputs; +} + +bool compareOuts(std::vector &lhs, std::vector &rhs) { + std::vector<::ttnn::Tensor *> lhsTensors; + std::vector<::ttnn::Tensor *> rhsTensors; + + for (auto &tensor : lhs) { + lhsTensors.push_back(static_cast<::ttnn::Tensor *>(tensor.handle.get())); + } + for (auto &tensor : rhs) { + rhsTensors.push_back(static_cast<::ttnn::Tensor *>(tensor.handle.get())); + } + + LOG_ASSERT(lhsTensors.size() == rhsTensors.size()); + for (size_t i = 0; i < lhsTensors.size(); i++) { + auto lhsTensor = lhsTensors[i]; + auto rhsTensor = rhsTensors[i]; + std::cout << "Dtype: " << (int)lhsTensor->get_dtype() << ", " + << (int)rhsTensor->get_dtype() << std::endl; + LOG_ASSERT(lhsTensor->get_dtype() == rhsTensor->get_dtype()); + std::cout << "Shape: " << lhsTensor->get_shape() << ", " + << rhsTensor->get_shape() << std::endl; + LOG_ASSERT(lhsTensor->get_shape() == rhsTensor->get_shape()); + std::cout << "Layout: " << (int)lhsTensor->get_layout() << ", " + << (int)rhsTensor->get_layout() << std::endl; + LOG_ASSERT(lhsTensor->get_layout() == rhsTensor->get_layout()); + std::cout << "Logical shape: " << lhsTensor->get_logical_shape() << ", " + << rhsTensor->get_logical_shape() << std::endl; + LOG_ASSERT(lhsTensor->get_logical_shape() == + rhsTensor->get_logical_shape()); + std::cout << "Volume: " << lhsTensor->volume() << ", " + << rhsTensor->volume() << std::endl; + LOG_ASSERT(lhsTensor->volume() == rhsTensor->volume()); + std::cout << "Element size in bytes: " << lhsTensor->element_size() << ", " + << rhsTensor->element_size() << std::endl; + LOG_ASSERT(lhsTensor->element_size() == rhsTensor->element_size()); + + std::cout << "Printing LHS:" << std::endl; + lhsTensor->print(); + std::cout << std::endl << std::endl; + std::cout << "Printing RHS:" << std::endl; + rhsTensor->print(); + + // Compare tensor data + // + uint8_t *lhsData = static_cast( + ::tt::tt_metal::get_raw_host_data_ptr(*lhsTensor)); + uint8_t *rhsData = static_cast( + ::tt::tt_metal::get_raw_host_data_ptr(*rhsTensor)); + + for (size_t i = 0; i < lhsTensor->volume() * lhsTensor->element_size(); + i++) { + if (lhsData[i] != rhsData[i]) { + std::cout << "Mismatch at byte number: " << i << ": " << (int)lhsData[i] + << " != " << (int)rhsData[i] << std::endl; + return false; + } + } + + std::cout << "Done" << std::endl << std::endl; + } + + return true; +} + } // namespace tt::runtime::ttnn diff --git a/runtime/tools/python/ttrt/common/run.py b/runtime/tools/python/ttrt/common/run.py index b83c5d390..7a75cb5ce 100644 --- a/runtime/tools/python/ttrt/common/run.py +++ b/runtime/tools/python/ttrt/common/run.py @@ -154,6 +154,13 @@ def initialize_api(): choices=None, help="test file to save results to", ) + Run.register_arg( + name="--emitc", + type=bool, + default=False, + choices=[True, False], + help="toggles emitc testing", + ) Run.register_arg( name="--disable-golden", type=bool, @@ -413,6 +420,8 @@ def _execute(binaries): get_callback_fn(callback_runtime_config) ) + is_emitc_testing_requested = self["--emitc"] + try: for bin in binaries: try: @@ -421,6 +430,11 @@ def _execute(binaries): if self["--save-artifacts"]: self.artifacts.create_binary_artifacts_folder(bin) + if is_emitc_testing_requested: + emitc_dylib_path = bin.file_path.replace(".ttnn", ".so") + emitc_dylib_handle = ttrt.runtime.open_so(emitc_dylib_path) + self.logging.debug(f"opened emitc dylib={emitc_dylib_path}") + program_indices = [] if self["--program-index"] == "all": program_indices.extend(range(bin.get_num_programs())) @@ -544,6 +558,33 @@ def _execute(binaries): if event is not None: ttrt.runtime.wait(event) + if is_emitc_testing_requested: + fwd_func_name = program.program["name"] + fwd_func_name_len = len(fwd_func_name) + fwd_func_sym = f"_Z{fwd_func_name_len}{fwd_func_name}St6vectorIN2tt8tt_metal6TensorESaIS2_EEPNS1_2v06DeviceE" + emitc_outs = ttrt.runtime.run_so_program( + emitc_dylib_handle, + fwd_func_sym, + inputs, + device, + ) + self.logging.debug( + f"got emitc outputs for program={program_index}" + ) + + all_tensors_match = ttrt.runtime.compare_outs( + total_outputs[0], emitc_outs + ) + + if not all_tensors_match: + self.logging.error( + "Failed: TTRT and EmitC outputs do not match!" + ) + self.logging.error(total_outputs[0], emitc_outs) + raise Exception( + "Failed: TTRT and EmitC outputs do not match!" + ) + if self["--identity"]: self.logging.debug( f"checking identity with rtol={self['--rtol']} and atol={self['--atol']}" diff --git a/runtime/tools/python/ttrt/common/util.py b/runtime/tools/python/ttrt/common/util.py index 3f1b1adf8..ba6c5dcfe 100644 --- a/runtime/tools/python/ttrt/common/util.py +++ b/runtime/tools/python/ttrt/common/util.py @@ -531,10 +531,10 @@ def check_version(self): except Exception as e: raise Exception(f"error retrieving version: {e} for {package_name}") - if package_version != self.version: - raise Exception( - f"{package_name}: v{package_version} does not match flatbuffer: v{self.version} for flatbuffer: {self.file_path} - skipping this test" - ) + # if package_version != self.version: + # raise Exception( + # f"{package_name}: v{package_version} does not match flatbuffer: v{self.version} for flatbuffer: {self.file_path} - skipping this test" + # ) return True diff --git a/runtime/tools/python/ttrt/runtime/__init__.py b/runtime/tools/python/ttrt/runtime/__init__.py index 5ff953da5..8976ed526 100644 --- a/runtime/tools/python/ttrt/runtime/__init__.py +++ b/runtime/tools/python/ttrt/runtime/__init__.py @@ -29,6 +29,9 @@ get_op_debug_str, memcpy, deallocate_tensor, + open_so, + run_so_program, + compare_outs, WorkaroundEnv, get_op_loc_info, unregister_hooks, diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 4c3eb8c69..6649d178a 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -175,6 +175,15 @@ PYBIND11_MODULE(_C, m) { "Copy the data from src tensor to dst tensor"); m.def("deallocate_tensor", &tt::runtime::deallocateTensor, py::arg("tensor"), py::arg("force") = false, "Deallocate the tensor memory"); + + m.def("open_so", &tt::runtime::openSo, py::arg("path"), + "Open a shared object file"); + m.def("run_so_program", &tt::runtime::runSoProgram, py::arg("so"), + py::arg("name"), py::arg("inputs"), py::arg("device"), + "Run a program from a shared object file"); + m.def("compare_outs", &tt::runtime::compareOuts, py::arg("lhs"), + py::arg("rhs")); + py::class_(m, "DebugEnv") .def_static("get", &tt::runtime::debug::Env::get) .def("__str__", [](const tt::runtime::debug::Env &env) { diff --git a/test/lit.cfg.py b/test/lit.cfg.py index d65acc7b2..a039ddb57 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -49,6 +49,13 @@ def set_system_desc_features(system_desc): # system_desc_path: The system desc that is to be used to generate the binary files. config.system_desc_path = os.getenv("SYSTEM_DESC_PATH", "") +# Add `TT_METAL_HOME` to the lit environment. +llvm_config.with_environment( + "TT_METAL_HOME_ORIGINAL", + os.getenv("TT_METAL_HOME"), + append_path=False, +) + if config.system_desc_path: try: import ttrt @@ -98,3 +105,33 @@ def set_system_desc_features(system_desc): ], append_path=True, ) + +# Add `tools/ttnn-standalone` to PATH. +llvm_config.with_environment( + "PATH", + [ + os.path.join(os.getenv("TT_MLIR_HOME"), "tools/ttnn-standalone"), + ], + append_path=True, +) + +# Add `TT_MLIR_HOME` to the lit environment. +llvm_config.with_environment( + "TT_MLIR_HOME", + os.getenv("TT_MLIR_HOME"), + append_path=False, +) + +# Add `ARCH_NAME` to the lit environment. +llvm_config.with_environment( + "ARCH_NAME", + os.getenv("ARCH_NAME"), + append_path=False, +) + +# Add `TT_METAL_HOME` to the lit environment. +llvm_config.with_environment( + "TT_METAL_HOME", + os.getenv("TT_METAL_HOME"), + append_path=False, +) diff --git a/test/ttmlir/Silicon/TTNN/emitc/mnist.mlir b/test/ttmlir/Silicon/TTNN/emitc/mnist.mlir index f3076360e..4fa009717 100644 --- a/test/ttmlir/Silicon/TTNN/emitc/mnist.mlir +++ b/test/ttmlir/Silicon/TTNN/emitc/mnist.mlir @@ -1,6 +1,11 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" --ttnn-create-input-gens --convert-ttnn-to-emitc %s > %t.mlir +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %basename_t.ttnn +// RUN: ttmlir-opt --ttnn-modify-signatures-for-dylib --convert-ttnn-to-emitc %t.mlir > %t2.mlir +// RUN: ttmlir-translate --mlir-to-cpp %t2.mlir > %basename_t.cpp +// RUN: compile_dylib.py %basename_t.cpp . +// UNSUPPORTED: true -module @MNISTLinear attributes {tt.system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>} { +module @MNISTLinear attributes {} { func.func @forward(%arg0: tensor<1x784xf32> {ttir.name = "input_1"}, %arg1: tensor<784x512xf32> {ttir.name = "linear_relu_stack.0.weight"}, %arg2: tensor<512xf32> {ttir.name = "linear_relu_stack.0.bias"}, %arg3: tensor<512x512xf32> {ttir.name = "linear_relu_stack.2.weight"}, %arg4: tensor<512xf32> {ttir.name = "linear_relu_stack.2.bias"}, %arg5: tensor<512x10xf32> {ttir.name = "linear_relu_stack.4.weight"}, %arg6: tensor<10xf32> {ttir.name = "linear_relu_stack.4.bias"}) -> (tensor<1x10xf32> {ttir.name = "MNISTLinear_350.output_add_981"}) { %0 = tensor.empty() : tensor<1x512xf32> %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<1x784xf32>, tensor<784x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32> diff --git a/test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir b/test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir index 951b36061..736462227 100644 --- a/test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir +++ b/test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir @@ -1,5 +1,8 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %basename_t.ttnn +// RUN: ttmlir-opt --ttnn-modify-signatures-for-dylib --convert-ttnn-to-emitc %t.mlir > %t2.mlir +// RUN: ttmlir-translate --mlir-to-cpp %t2.mlir > %basename_t.cpp +// RUN: compile_dylib.py %basename_t.cpp . func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { %0 = tensor.empty() : tensor<32x32xbf16> diff --git a/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir b/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir index 8fc4d2e9b..1fc010bdc 100644 --- a/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir +++ b/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir @@ -1,5 +1,8 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %basename_t.ttnn +// RUN: ttmlir-opt --ttnn-modify-signatures-for-dylib --convert-ttnn-to-emitc %t.mlir > %t2.mlir +// RUN: ttmlir-translate --mlir-to-cpp %t2.mlir > %basename_t.cpp +// RUN: compile_dylib.py %basename_t.cpp . func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { %0 = tensor.empty() : tensor<32x32xbf16> diff --git a/tools/ttnn-standalone/CMakeLists.txt b/tools/ttnn-standalone/CMakeLists.txt index bf31fdc36..978c8457c 100644 --- a/tools/ttnn-standalone/CMakeLists.txt +++ b/tools/ttnn-standalone/CMakeLists.txt @@ -48,6 +48,13 @@ if("$ENV{ARCH_NAME}" STREQUAL "") message(FATAL_ERROR "ARCH_NAME is not set") endif() +# Check if TT_METAL_HOME_ORIGINAL is set +if(DEFINED ENV{TT_METAL_HOME_ORIGINAL}) + # Swap TT_METAL_HOME with TT_METAL_HOME_ORIGINAL + set(ENV{TT_METAL_HOME} $ENV{TT_METAL_HOME_ORIGINAL}) + message(STATUS "Swapped TT_METAL_HOME with TT_METAL_HOME_ORIGINAL.") +endif() + message($ENV{TT_METAL_HOME}/tt_metal/third_party/src/firmware/riscv/$ENV{ARCH_NAME}) # Directories to search for headers diff --git a/tools/ttnn-standalone/compile_dylib.py b/tools/ttnn-standalone/compile_dylib.py new file mode 100755 index 000000000..1fbf4cbdb --- /dev/null +++ b/tools/ttnn-standalone/compile_dylib.py @@ -0,0 +1,112 @@ +#!/opt/ttmlir-toolchain/venv/bin/python +# -*- coding: utf-8 -*- +# +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import sys +import os +import subprocess +import shutil + + +def main(): + if len(sys.argv) != 3: + print("Usage: script.py ") + sys.exit(1) + + for name, value in os.environ.items(): + print("{0}: {1}".format(name, value)) + + cpp_file_path = sys.argv[1] + output_dir = sys.argv[2] + + # Verify the input file exists + if not os.path.isfile(cpp_file_path): + print(f"Error: File '{cpp_file_path}' does not exist.") + sys.exit(1) + + # Verify the output directory exists + if not os.path.isdir(output_dir): + print(f"Error: Directory '{output_dir}' does not exist.") + sys.exit(1) + + # Define the path to the target file + tt_mlir_home = os.environ.get("TT_MLIR_HOME") + if not tt_mlir_home: + print("Error: TT_MLIR_HOME environment variable is not set.") + sys.exit(1) + + target_file_path = os.path.join( + tt_mlir_home, "tools/ttnn-standalone/ttnn-dylib.cpp" + ) + + try: + # Read contents of the input file + with open(cpp_file_path, "r") as source_file: + cpp_content = source_file.read() + + # Overwrite the target file + with open(target_file_path, "w") as target_file: + target_file.write(cpp_content) + + print( + f"Successfully updated {target_file_path} with contents from {cpp_file_path}." + ) + except Exception as e: + print(f"Error while handling files: {e}") + sys.exit(1) + + # Define the commands to be executed + build_dir = os.path.join(tt_mlir_home, "tools/ttnn-standalone/build") + cmake_command = [ + "cmake", + "-G", + "Ninja", + "-B", + build_dir, + "-S", + os.path.join(tt_mlir_home, "tools/ttnn-standalone"), + "-DCMAKE_BUILD_TYPE=Release", + "-DCMAKE_C_COMPILER=clang", + "-DCMAKE_CXX_COMPILER=clang++", + ] + + build_command = ["cmake", "--build", build_dir, "--", "ttnn-dylib"] + + try: + # Run the cmake command + print("Running cmake command...") + subprocess.run(cmake_command, check=True, cwd=tt_mlir_home) + + # Run the build command + print("Building ttnn-dylib...") + subprocess.run(build_command, check=True, cwd=tt_mlir_home) + + print("Build completed successfully.") + + # Determine the output .so file + compiled_so_path = os.path.join(build_dir, "libttnn-dylib.so") + if not os.path.isfile(compiled_so_path): + print(f"Error: Compiled file '{compiled_so_path}' not found.") + sys.exit(1) + + # Define the destination path with renamed file + output_file_name = os.path.basename(cpp_file_path) + output_file_name = os.path.splitext(output_file_name)[0] + ".so" + destination_path = os.path.join(output_dir, output_file_name) + + # Copy and rename the .so file + shutil.copy2(compiled_so_path, destination_path) + print(f"Successfully copied compiled file to {destination_path}.") + except subprocess.CalledProcessError as e: + print(f"Error during build process: {e}") + sys.exit(1) + except Exception as e: + print(f"Error during file operations: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tools/ttnn-standalone/ttnn-dylib.cpp b/tools/ttnn-standalone/ttnn-dylib.cpp index 1993b7de3..1f976da26 100644 --- a/tools/ttnn-standalone/ttnn-dylib.cpp +++ b/tools/ttnn-standalone/ttnn-dylib.cpp @@ -2,44 +2,46 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttnn-dylib.hpp" +template +std::vector utilCreateVec(T &&...t) { + return std::vector{std::forward(t)...}; +} -// Forward function example -// -std::vector forward(std::vector inputs) { - ttnn::Tensor v1 = inputs[0]; - ttnn::Tensor v2 = inputs[1]; - ttnn::Device *v3 = ttnn::DeviceGetter::getInstance(); - ttnn::MemoryConfig v4 = ttnn::MemoryConfig( +#include "ttnn-precompiled.hpp" +std::vector add(std::vector v1, ttnn::Device *v2) { + ttnn::Tensor v3 = v1[0]; + ttnn::Tensor v4 = v1[1]; + ttnn::MemoryConfig v5 = ttnn::MemoryConfig( ttnn::TensorMemoryLayout::INTERLEAVED, ttnn::BufferType::DRAM); - ttnn::Tensor v5 = ttnn::to_device(v1, v3, v4); - ttnn::Tensor v6 = - ttnn::to_layout(v5, ttnn::Layout::TILE, std::nullopt, std::nullopt, + ttnn::Tensor v6 = ttnn::to_device(v3, v2, v5); + ttnn::Tensor v7 = + ttnn::to_layout(v6, ttnn::Layout::TILE, std::nullopt, std::nullopt, static_cast<::ttnn::Device *>(nullptr)); - ttnn::deallocate(v5, false); - ttnn::MemoryConfig v7 = ttnn::MemoryConfig( + ttnn::deallocate(v6, false); + ttnn::MemoryConfig v8 = ttnn::MemoryConfig( ttnn::TensorMemoryLayout::INTERLEAVED, ttnn::BufferType::DRAM); - ttnn::Tensor v8 = ttnn::to_device(v2, v3, v7); - ttnn::Tensor v9 = - ttnn::to_layout(v8, ttnn::Layout::TILE, std::nullopt, std::nullopt, + ttnn::Tensor v9 = ttnn::to_device(v4, v2, v8); + ttnn::Tensor v10 = + ttnn::to_layout(v9, ttnn::Layout::TILE, std::nullopt, std::nullopt, static_cast<::ttnn::Device *>(nullptr)); - ttnn::deallocate(v8, false); - ttnn::Shape v10 = ttnn::Shape(tt::tt_metal::LegacyShape({ + ttnn::deallocate(v9, false); + ttnn::Shape v11 = ttnn::Shape(tt::tt_metal::LegacyShape({ 32, 32, })); - ttnn::MemoryConfig v11 = ttnn::MemoryConfig( + ttnn::MemoryConfig v12 = ttnn::MemoryConfig( ttnn::TensorMemoryLayout::INTERLEAVED, ttnn::BufferType::DRAM); - ttnn::Tensor v12 = - ttnn::empty(v10, ttnn::DataType::BFLOAT16, ttnn::Layout::TILE, v3, v11); - ttnn::Tensor v13 = ttnn::add(v6, v9, std::nullopt, std::nullopt, v12); - ttnn::deallocate(v9, false); - ttnn::deallocate(v6, false); - ttnn::Tensor v14 = ttnn::from_device(v13); - ttnn::deallocate(v12, false); - ttnn::Tensor v15 = - ttnn::to_layout(v14, ttnn::Layout::ROW_MAJOR, std::nullopt, std::nullopt, + ttnn::Tensor v13 = + ttnn::empty(v11, ttnn::DataType::BFLOAT16, ttnn::Layout::TILE, v2, v12); + ttnn::Tensor v14 = ttnn::add(v7, v10, std::nullopt, std::nullopt, v13); + ttnn::deallocate(v10, false); + ttnn::deallocate(v7, false); + ttnn::Tensor v15 = ttnn::from_device(v14); + ttnn::deallocate(v13, false); + ttnn::Tensor v16 = + ttnn::to_layout(v15, ttnn::Layout::ROW_MAJOR, std::nullopt, std::nullopt, static_cast<::ttnn::Device *>(nullptr)); - ttnn::deallocate(v14, false); - return std::vector{v15}; + ttnn::deallocate(v15, false); + std::vector v17 = utilCreateVec(v16); + return v17; } diff --git a/tools/ttnn-standalone/ttnn-dylib.hpp b/tools/ttnn-standalone/ttnn-dylib.hpp index 26cbeb717..1ba749ef8 100644 --- a/tools/ttnn-standalone/ttnn-dylib.hpp +++ b/tools/ttnn-standalone/ttnn-dylib.hpp @@ -4,4 +4,5 @@ #include "ttnn-precompiled.hpp" -std::vector forward(std::vector inputs); +std::vector forward(std::vector inputs, + ttnn::Device *device);