Skip to content

Commit

Permalink
#1190: Added runtime support for doing golden comparision for flatbuf…
Browse files Browse the repository at this point in the history
…fers in ttrt (#1218)
  • Loading branch information
tapspatel authored Nov 15, 2024
1 parent 3617528 commit 2f1078c
Show file tree
Hide file tree
Showing 20 changed files with 656 additions and 66 deletions.
2 changes: 1 addition & 1 deletion docs/src/ttrt.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ ttrt query --save-artifacts
4. Use ttmlir-opt tool in compiler to feed system descriptor. See the [ttmlir-opt](./ttmlir-opt.md) documentation for more information on how to generate .mlir files.
```bash
./build/bin/ttmlir-opt --ttir-load-system-desc="path=/path/to/system_desc.ttsys" --ttir-to-ttnn-backend-pipeline test/ttmlir/Dialect/TTNN/simple_subtract.mlir -o ttnn.mlir
or (pip path directly into ttir-to-ttnn-backend-pipeline)
or (pipe path directly into ttir-to-ttnn-backend-pipeline)
./build/bin/ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=/path/to/system_desc.ttsys" test/ttmlir/Dialect/TTNN/simple_subtract.mlir -o ttnn.mlir
```
5. Use ttmlir-translate tool in compiler to generate the flatbuffer executable. See the [ttmlir-translate](./ttmlir-translate.md) documentation for more information on how to generate flatbuffer files.
Expand Down
4 changes: 1 addition & 3 deletions python/test_infra/ttir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,7 @@ def eltwise_proxy(
inputs,
[output],
self._get_operand_constraint_attr(3),
loc=Location.file(
filename=stack[0].filename, line=stack[0].lineno, col=id
),
loc=Location.name(str(id)),
)

goldens = []
Expand Down
44 changes: 44 additions & 0 deletions runtime/include/tt/runtime/detail/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
#ifndef TT_RUNTIME_DETAIL_DEBUG_H
#define TT_RUNTIME_DETAIL_DEBUG_H

#include <functional>
#include <optional>
#include <ostream>

#include "tt/runtime/types.h"

namespace tt::runtime::debug {

struct Env {
Expand Down Expand Up @@ -41,6 +45,46 @@ inline std::ostream &operator<<(std::ostream &os, Env const &env) {
return os;
}

struct Hooks {
#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
static Hooks const &
get(std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
operatorCallback = std::nullopt);
#else
constexpr static Hooks get() { return Hooks(); }
#endif

std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
getOperatorCallback() const {
#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
return operatorCallback;
#else
return std::nullopt;
#endif
}

private:
#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
Hooks(std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
operatorCallback)
: operatorCallback(operatorCallback) {}

std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
operatorCallback;
#else
constexpr Hooks() = default;
#endif
};

inline std::ostream &operator<<(std::ostream &os, Hooks const &hooks) {
os << "debug::Hooks{\n"
<< "\t"
<< "operatorCallback: " << static_cast<bool>(hooks.getOperatorCallback())
<< ",\n"
<< "}";
return os;
}

} // namespace tt::runtime::debug

#endif // TT_RUNTIME_DETAIL_DEBUG_H
7 changes: 7 additions & 0 deletions runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex,

void wait(Event event);

std::string getOpDebugString(OpContext opContextHandle);

Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle);

std::vector<float> getTensorData(Tensor tensor);

using InputBuffer =
std::tuple<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>,
std::shared_ptr<::tt::tt_metal::Event>>;
Expand Down
11 changes: 9 additions & 2 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,15 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex,

void wait(Event event);

void runProgram(::ttnn::MeshDevice &meshDevice,
::tt::target::ttnn::Program const *program,
std::string getOpDebugString(OpContext opContextHandle);

Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle);

std::vector<float> getTensorData(Tensor tensor);

void runProgram(::ttnn::MeshDevice &meshDevice, Binary &executableHandle,
std::uint32_t programIndex,
std::vector<::ttnn::Tensor *> const &inputs,
std::vector<::ttnn::Tensor *> const &outputs);

Expand Down
7 changes: 7 additions & 0 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex,

void wait(Event event);

std::string getOpDebugString(OpContext opContextHandle);

Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle);

std::vector<float> getTensorData(Tensor tensor);

} // namespace tt::runtime

#endif
11 changes: 11 additions & 0 deletions runtime/include/tt/runtime/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wcovered-switch-default"
#include "ttmlir/Target/Common/debug_info_generated.h"
#include "ttmlir/Target/Common/system_desc_generated.h"
#include "ttmlir/Target/Common/types_generated.h"
#pragma clang diagnostic pop
Expand Down Expand Up @@ -108,6 +109,7 @@ struct Binary : public Flatbuffer {

std::vector<TensorDesc> getProgramInputs(std::uint32_t programIndex) const;
std::vector<TensorDesc> getProgramOutputs(std::uint32_t programIndex) const;
const ::tt::target::GoldenTensor *getDebugInfoGolden(std::string &loc) const;
};

struct Device : public detail::RuntimeCheckedObjectImpl {
Expand All @@ -120,11 +122,20 @@ struct Event : public detail::RuntimeCheckedObjectImpl {

struct Tensor : public detail::RuntimeCheckedObjectImpl {
std::shared_ptr<void> data;

Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data,
DeviceRuntime runtime)
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data) {}
};

struct CallbackContext : public detail::RuntimeCheckedObjectImpl {
using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl;
};

struct OpContext : public detail::RuntimeCheckedObjectImpl {
using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl;
};

} // namespace tt::runtime

#endif
39 changes: 39 additions & 0 deletions runtime/lib/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,23 @@ std::vector<TensorDesc> getProgramOutputs(Flatbuffer binary,
return outputs;
}

const ::tt::target::GoldenTensor *getDebugInfoGolden(Flatbuffer binary,
std::string &loc) {
auto const *programs = getBinary(binary)->programs();
for (auto const *program : *programs) {
for (const ::tt::target::GoldenKV *goldenKV :
*program->debug_info()->golden_info()->golden_map()) {
if (std::string(goldenKV->key()->c_str()) == loc) {
return goldenKV->value();
;
}
}
}

LOG_WARNING("Golden information not found");
return nullptr;
}

} // namespace ttnn

namespace metal {
Expand Down Expand Up @@ -177,6 +194,12 @@ std::vector<TensorDesc> getProgramOutputs(Flatbuffer binary,
return outputs;
}

const ::tt::target::GoldenTensor *getDebugInfoGolden(Flatbuffer binary,
std::string &loc) {
LOG_WARNING("Debug golden information not enabled for metal yet!");
return nullptr;
}

} // namespace metal

namespace system_desc {
Expand Down Expand Up @@ -344,4 +367,20 @@ Binary::getProgramOutputs(std::uint32_t programIndex) const {
throw std::runtime_error("Unsupported binary format");
}

const ::tt::target::GoldenTensor *
Binary::getDebugInfoGolden(std::string &loc) const {
if (::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier(
handle.get())) {
return ttnn::getDebugInfoGolden(*this, loc);
}

if (::tt::target::metal::SizePrefixedTTMetalBinaryBufferHasIdentifier(
handle.get())) {
return metal::getDebugInfoGolden(*this, loc);
}

throw std::runtime_error(
"Unsupported binary format for obtaining golden information");
}

} // namespace tt::runtime
11 changes: 11 additions & 0 deletions runtime/lib/common/debug.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ Env const &Env::get(bool loadKernelsFromDisk, bool enableAsyncTTNN) {
return config;
}

#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
Hooks const &Hooks::get(
std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
operatorCallback) {
static Hooks config(operatorCallback);
return config;
}
#else
Hooks get() { return Hooks(); }
#endif

} // namespace tt::runtime::debug

#endif
50 changes: 50 additions & 0 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#if defined(TT_RUNTIME_ENABLE_TTNN)
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/types.h"
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
Expand Down Expand Up @@ -245,4 +246,53 @@ void wait(Event event) {
throw std::runtime_error("runtime is not enabled");
}

std::string getOpDebugString(OpContext opContextHandle) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getOpDebugString(opContextHandle);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getOpDebugString(opContextHandle);
}
#endif
throw std::runtime_error("runtime is not enabled");
}

Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getOpOutputTensor(opContextHandle,
programContextHandle);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getOpOutputTensor(opContextHandle,
programContextHandle);
}
#endif
throw std::runtime_error("runtime is not enabled");
}

std::vector<float> getTensorData(Tensor tensor) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getTensorData(tensor);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getTensorData(tensor);
}
#endif

throw std::runtime_error("runtime is not enabled");
}

} // namespace tt::runtime
23 changes: 23 additions & 0 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ static ::tt::target::metal::TTMetalBinary const *getBinary(Flatbuffer binary) {
return ::tt::target::metal::GetSizePrefixedTTMetalBinary(binary.handle.get());
}

static Tensor createNullTensor() {
return Tensor(nullptr, nullptr, DeviceRuntime::TTMetal);
}

Tensor createTensor(std::shared_ptr<void> data,
std::vector<std::uint32_t> const &shape,
std::vector<std::uint32_t> const &stride,
Expand Down Expand Up @@ -252,4 +256,23 @@ void wait(Event event) {
}
}

std::string getOpDebugString(OpContext opContextHandle) {
// Not implemented
LOG_WARNING("obtaining op debug string for metal runtime not implemented");
return "";
}

Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle) {
// Not implemented
LOG_WARNING("obtaining op output tensor for metal runtime not implemented");
return createNullTensor();
}

std::vector<float> getTensorData(Tensor tensor) {
// Not implemented
LOG_WARNING("obtaining tensor data for metal runtime not implemented");
return {};
}

} // namespace tt::runtime::ttmetal
6 changes: 6 additions & 0 deletions runtime/lib/ttnn/include/tt/runtime/ttnn/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class ProgramTensorPool {
return *liveTensors.at(globalId);
}

const ::ttnn::Tensor &at(std::uint32_t globalId) const {
assert(liveTensors.contains(globalId));
return *liveTensors.at(globalId);
}

size_t erase(std::uint32_t globalId) {
assert(liveTensors.contains(globalId) &&
intermedTensors.contains(globalId));
Expand Down Expand Up @@ -161,6 +166,7 @@ class ProgramContext {
// Tensor Pool Operations
//
ProgramTensorPool &getTensorPool() { return tensorPool; }
const ProgramTensorPool &getTensorPool() const { return tensorPool; }

private:
ProgramTensorPool tensorPool;
Expand Down
Loading

0 comments on commit 2f1078c

Please sign in to comment.