Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
svuckovicTT committed Dec 27, 2024
1 parent 6d04d25 commit c6da33c
Show file tree
Hide file tree
Showing 17 changed files with 487 additions and 44 deletions.
16 changes: 12 additions & 4 deletions .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,11 @@ jobs:
fail-fast: false
matrix:
build: [
{runs-on: n150, enable_perf: OFF, name: "run", ttrt_flags: "--non-zero"},
{runs-on: n150, enable_perf: ON, name: "perf"},
{runs-on: n300, enable_perf: OFF, name: "run", ttrt_flags: "--non-zero"},
{runs-on: n300, enable_perf: ON, name: "perf"},
{runs-on: n150, enable_perf: OFF, emitc: OFF, name: "run", ttrt_flags: "--non-zero"},
{runs-on: n150, enable_perf: ON, emitc: OFF, name: "perf"},
{runs-on: n150, enable_perf: OFF, emitc: ON, name: "run", ttrt_flags: "--emitc"},
{runs-on: n300, enable_perf: OFF, emitc: OFF, name: "run", ttrt_flags: "--non-zero"},
{runs-on: n300, enable_perf: ON, emitc: OFF, name: "perf"},
]
name: "run-tests (${{ matrix.build.runs-on }}, ${{ matrix.build.enable_perf }}, ${{ matrix.build.name }})"

Expand Down Expand Up @@ -374,6 +375,13 @@ jobs:
ttrt ${{ matrix.build.name }} ${{ matrix.build.ttrt_flags }} ${{ steps.strings.outputs.build-output-dir }}/test/ttmlir/Silicon/TTNN/perf_unit
cp ttrt_report.xml ${{ steps.strings.outputs.test_report_path }}
- name: Run emitc tests
shell: bash
if: matrix.build.emitc == 'ON'
run: |
source env/activate
ttrt ${{ matrix.build.name }} ${{ matrix.build.ttrt_flags }} ${{ steps.strings.outputs.build-output-dir }}/test/ttmlir/Silicon/TTNN/emitc
- name: Upload ttrt test report json
if: always()
uses: actions/upload-artifact@v4
Expand Down
5 changes: 5 additions & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ std::vector<Tensor> runProgram(::ttnn::MeshDevice &meshDevice,
std::uint32_t programIndex,
std::vector<::ttnn::Tensor *> const &inputs);

bool compareOuts(std::vector<Tensor> &lhs, std::vector<Tensor> &rhs);

std::vector<Tensor> do_stuff(void *so, std::string func_name,
std::vector<Tensor> inputs, Device device);

} // namespace tt::runtime::ttnn

#endif
7 changes: 7 additions & 0 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ Event submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex, std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);

void *openSo(std::string path);

std::vector<Tensor> runSoProgram(void *so, std::string name,
std::vector<Tensor> inputs, Device device);

bool compareOuts(std::vector<Tensor> &lhs, std::vector<Tensor> &rhs);

} // namespace tt::runtime

#endif
34 changes: 34 additions & 0 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#if defined(TT_RUNTIME_ENABLE_TTNN)
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/types.h"

#include <dlfcn.h>
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
Expand Down Expand Up @@ -486,4 +488,36 @@ Event submit(Device deviceHandle, Binary executableHandle,
#endif
LOG_FATAL("runtime is not enabled");
}

// Opens a shared object (e.g. an EmitC-compiled model dylib) and returns the
// raw dlopen handle.
//
// The caller owns the returned handle and is responsible for closing it.
// Throws std::runtime_error if the object cannot be loaded, or if the TTNN
// runtime is not the currently active runtime.
void *openSo(std::string path) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
    // Clear any stale error state so dlerror() below reports only this call.
    dlerror();
    void *handle = dlopen(path.c_str(), RTLD_LAZY);
    if (!handle) {
      const char *error = dlerror();
      std::cerr << "Failed to load shared object: "
                << (error ? error : "unknown error") << '\n';
      throw std::runtime_error("Failed to load shared object");
    }
    return handle;
  }
#endif
  throw std::runtime_error("ttnn runtime is not enabled");
}

// Runs a forward function loaded from a shared object on the given device.
//
// `so` is a handle previously returned by openSo(); `name` is the (mangled)
// symbol name of the forward function to invoke. Dispatches to the TTNN
// implementation; like the other dispatchers in this file, it is guarded so
// the translation unit still builds when TTNN support is disabled.
//
// Throws std::runtime_error if the TTNN runtime is not enabled/active.
std::vector<Tensor> runSoProgram(void *so, std::string name,
                                 std::vector<Tensor> inputs, Device device) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
    return ::tt::runtime::ttnn::do_stuff(so, name, inputs, device);
  }
#endif
  throw std::runtime_error("ttnn runtime is not enabled");
}

// Dispatches output-tensor comparison to the TTNN runtime implementation.
//
// Throws std::runtime_error if the TTNN runtime is not enabled/active.
bool compareOuts(std::vector<Tensor> &lhs, std::vector<Tensor> &rhs) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  const bool ttnnActive = (getCurrentRuntime() == DeviceRuntime::TTNN);
  if (ttnnActive) {
    return ::tt::runtime::ttnn::compareOuts(lhs, rhs);
  }
#endif
  throw std::runtime_error("ttnn runtime is not enabled");
}

} // namespace tt::runtime
168 changes: 167 additions & 1 deletion runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "tt/runtime/runtime.h"

#include "tt/runtime/detail/debug.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
Expand All @@ -12,6 +15,10 @@
#include "ttnn/tensor/shape/small_vector.hpp"
#include "ttnn/tensor/types.hpp"

#include <cstdint>
#include <dlfcn.h>
#include <memory>

namespace tt::runtime::ttnn {

using ::tt::runtime::DeviceRuntime;
Expand Down Expand Up @@ -154,7 +161,7 @@ Tensor createTensor(Device device, Layout layout,
createOwnedTensor(nullptr, shape, stride, itemsize,
utils::fromTTNNDataType(layoutDesc.dataType));
Tensor out = utils::createRuntimeTensorFromTTNN(tensor);
return toLayout(out, device, layout);
return ::tt::runtime::ttnn::toLayout(out, device, layout);
}
DeviceVariant targetDevice =
getTargetDevice(device.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN));
Expand Down Expand Up @@ -527,4 +534,163 @@ std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
return outputs;
}

std::vector<Tensor> do_stuff(void *so, std::string func_name,
std::vector<Tensor> inputs, Device device) {

::ttnn::MeshDevice &ttnnMeshDevice =
device.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN);

assert(ttnnMeshDevice.get_devices().size() == 1);

::ttnn::Device *ttnnDevice = ttnnMeshDevice.get_devices()[0];

// Convert inputs to TTNN tensors using .as method
//
std::vector<::ttnn::Tensor> ttnnInputs;
for (auto &input : inputs) {
LOG_ASSERT(input.matchesRuntime(DeviceRuntime::TTNN));
ttnnInputs.push_back(input.as<::ttnn::Tensor>(DeviceRuntime::TTNN));
}

// Clear previous error
//
dlerror();

// Get function from shared object
//
using ForwardFunction = std::vector<::ttnn::Tensor> (*)(
std::vector<::ttnn::Tensor>, ::ttnn::Device *);
std::cout << "before" << std::endl;
ForwardFunction forwardFunc = (ForwardFunction)dlsym(so, func_name.c_str());
std::cout << "after" << std::endl;

const char *dlsym_error = dlerror();
if (dlsym_error) {
std::cerr << "Failed to load symbol: " << dlsym_error << std::endl;
dlclose(so);
throw std::runtime_error("Failed to load symbol");
}

// Call function
//
std::vector<::ttnn::Tensor> ttnnOutputs = forwardFunc(ttnnInputs, ttnnDevice);

// Convert outputs to Tensor using Tensor constructor
//
std::vector<Tensor> outputs;
for (::ttnn::Tensor &output : ttnnOutputs) {
// using Storage = std::variant<OwnedStorage, DeviceStorage,
// BorrowedStorage, MultiDeviceHostStorage, MultiDeviceStorage>;
if (std::holds_alternative<OwnedStorage>(
output.tensor_attributes->storage)) {
std::cout << "OwnedStorage" << std::endl;
} else if (std::holds_alternative<DeviceStorage>(
output.tensor_attributes->storage)) {
std::cout << "DeviceStorage" << std::endl;
} else if (std::holds_alternative<BorrowedStorage>(
output.tensor_attributes->storage)) {
std::cout << "BorrowedStorage" << std::endl;
} else if (std::holds_alternative<MultiDeviceHostStorage>(
output.tensor_attributes->storage)) {
std::cout << "MultiDeviceHostStorage" << std::endl;
} else if (std::holds_alternative<MultiDeviceStorage>(
output.tensor_attributes->storage)) {
std::cout << "MultiDeviceStorage" << std::endl;
} else {
std::cout << "Unknown" << std::endl;
}

// BorrowedBuffer borrowedBuffer =
// std::get<BorrowedStorage>(output.tensor_attributes->storage).buffer;
// std::visit(
// [&outputs, &output](auto &&buffer) {
// outputs.push_back(
// Tensor(std::make_shared<::ttnn::Tensor>(std::move(output)),
// std::shared_ptr<void>(static_cast<void
// *>(buffer.data()),
// [](void *) {}),
// DeviceRuntime::TTNN));
// },
// borrowedBuffer);

OwnedStorage ownedStorage =
std::get<OwnedStorage>(output.tensor_attributes->storage).buffer;

std::visit(
[&outputs, &output](auto &&buffer) {
outputs.push_back(
Tensor(std::make_shared<::ttnn::Tensor>(std::move(output)),
std::shared_ptr<void>(static_cast<void *>(buffer.data()),
[](void *) {}),
DeviceRuntime::TTNN));
},
ownedStorage.get_buffer());
}

return outputs;
}

bool compareOuts(std::vector<Tensor> &lhs, std::vector<Tensor> &rhs) {
std::vector<::ttnn::Tensor *> lhsTensors;
std::vector<::ttnn::Tensor *> rhsTensors;

for (auto &tensor : lhs) {
lhsTensors.push_back(static_cast<::ttnn::Tensor *>(tensor.handle.get()));
}
for (auto &tensor : rhs) {
rhsTensors.push_back(static_cast<::ttnn::Tensor *>(tensor.handle.get()));
}

LOG_ASSERT(lhsTensors.size() == rhsTensors.size());
for (size_t i = 0; i < lhsTensors.size(); i++) {
auto lhsTensor = lhsTensors[i];
auto rhsTensor = rhsTensors[i];
std::cout << "Dtype: " << (int)lhsTensor->get_dtype() << ", "
<< (int)rhsTensor->get_dtype() << std::endl;
LOG_ASSERT(lhsTensor->get_dtype() == rhsTensor->get_dtype());
std::cout << "Shape: " << lhsTensor->get_shape() << ", "
<< rhsTensor->get_shape() << std::endl;
LOG_ASSERT(lhsTensor->get_shape() == rhsTensor->get_shape());
std::cout << "Layout: " << (int)lhsTensor->get_layout() << ", "
<< (int)rhsTensor->get_layout() << std::endl;
LOG_ASSERT(lhsTensor->get_layout() == rhsTensor->get_layout());
std::cout << "Logical shape: " << lhsTensor->get_logical_shape() << ", "
<< rhsTensor->get_logical_shape() << std::endl;
LOG_ASSERT(lhsTensor->get_logical_shape() ==
rhsTensor->get_logical_shape());
std::cout << "Volume: " << lhsTensor->volume() << ", "
<< rhsTensor->volume() << std::endl;
LOG_ASSERT(lhsTensor->volume() == rhsTensor->volume());
std::cout << "Element size in bytes: " << lhsTensor->element_size() << ", "
<< rhsTensor->element_size() << std::endl;
LOG_ASSERT(lhsTensor->element_size() == rhsTensor->element_size());

std::cout << "Printing LHS:" << std::endl;
lhsTensor->print();
std::cout << std::endl << std::endl;
std::cout << "Printing RHS:" << std::endl;
rhsTensor->print();

// Compare tensor data
//
uint8_t *lhsData = static_cast<uint8_t *>(
::tt::tt_metal::get_raw_host_data_ptr(*lhsTensor));
uint8_t *rhsData = static_cast<uint8_t *>(
::tt::tt_metal::get_raw_host_data_ptr(*rhsTensor));

for (size_t i = 0; i < lhsTensor->volume() * lhsTensor->element_size();
i++) {
if (lhsData[i] != rhsData[i]) {
std::cout << "Mismatch at byte number: " << i << ": " << (int)lhsData[i]
<< " != " << (int)rhsData[i] << std::endl;
return false;
}
}

std::cout << "Done" << std::endl << std::endl;
}

return true;
}

} // namespace tt::runtime::ttnn
41 changes: 41 additions & 0 deletions runtime/tools/python/ttrt/common/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ def initialize_api():
choices=None,
help="test file to save results to",
)
Run.register_arg(
name="--emitc",
type=bool,
default=False,
choices=[True, False],
help="toggles emitc testing",
)
Run.register_arg(
name="--disable-golden",
type=bool,
Expand Down Expand Up @@ -413,6 +420,8 @@ def _execute(binaries):
get_callback_fn(callback_runtime_config)
)

is_emitc_testing_requested = self["--emitc"]

try:
for bin in binaries:
try:
Expand All @@ -421,6 +430,11 @@ def _execute(binaries):
if self["--save-artifacts"]:
self.artifacts.create_binary_artifacts_folder(bin)

if is_emitc_testing_requested:
emitc_dylib_path = bin.file_path.replace(".ttnn", ".so")
emitc_dylib_handle = ttrt.runtime.open_so(emitc_dylib_path)
self.logging.debug(f"opened emitc dylib={emitc_dylib_path}")

program_indices = []
if self["--program-index"] == "all":
program_indices.extend(range(bin.get_num_programs()))
Expand Down Expand Up @@ -544,6 +558,33 @@ def _execute(binaries):
if event is not None:
ttrt.runtime.wait(event)

if is_emitc_testing_requested:
fwd_func_name = program.program["name"]
fwd_func_name_len = len(fwd_func_name)
fwd_func_sym = f"_Z{fwd_func_name_len}{fwd_func_name}St6vectorIN2tt8tt_metal6TensorESaIS2_EEPNS1_2v06DeviceE"
emitc_outs = ttrt.runtime.run_so_program(
emitc_dylib_handle,
fwd_func_sym,
inputs,
device,
)
self.logging.debug(
f"got emitc outputs for program={program_index}"
)

all_tensors_match = ttrt.runtime.compare_outs(
total_outputs[0], emitc_outs
)

if not all_tensors_match:
self.logging.error(
"Failed: TTRT and EmitC outputs do not match!"
)
self.logging.error(total_outputs[0], emitc_outs)
raise Exception(
"Failed: TTRT and EmitC outputs do not match!"
)

if self["--identity"]:
self.logging.debug(
f"checking identity with rtol={self['--rtol']} and atol={self['--atol']}"
Expand Down
8 changes: 4 additions & 4 deletions runtime/tools/python/ttrt/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,10 +531,10 @@ def check_version(self):
except Exception as e:
raise Exception(f"error retrieving version: {e} for {package_name}")

if package_version != self.version:
raise Exception(
f"{package_name}: v{package_version} does not match flatbuffer: v{self.version} for flatbuffer: {self.file_path} - skipping this test"
)
# if package_version != self.version:
# raise Exception(
# f"{package_name}: v{package_version} does not match flatbuffer: v{self.version} for flatbuffer: {self.file_path} - skipping this test"
# )

return True

Expand Down
3 changes: 3 additions & 0 deletions runtime/tools/python/ttrt/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
get_op_debug_str,
memcpy,
deallocate_tensor,
open_so,
run_so_program,
compare_outs,
WorkaroundEnv,
get_op_loc_info,
unregister_hooks,
Expand Down
Loading

0 comments on commit c6da33c

Please sign in to comment.