diff --git a/docs/src/ttrt.md b/docs/src/ttrt.md
index d5d6d87c7..b7808e237 100644
--- a/docs/src/ttrt.md
+++ b/docs/src/ttrt.md
@@ -58,7 +58,7 @@ ttrt query --save-artifacts
 4. Use ttmlir-opt tool in compiler to feed system descriptor. See the [ttmlir-opt](./ttmlir-opt.md) documentation for more information on how to generate .mlir files.
 ```bash
 ./build/bin/ttmlir-opt --ttir-load-system-desc="path=/path/to/system_desc.ttsys" --ttir-to-ttnn-backend-pipeline test/ttmlir/Dialect/TTNN/simple_subtract.mlir -o ttnn.mlir
-or (pip path directly into ttir-to-ttnn-backend-pipeline)
+or (pipe path directly into ttir-to-ttnn-backend-pipeline)
 ./build/bin/ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=/path/to/system_desc.ttsys" test/ttmlir/Dialect/TTNN/simple_subtract.mlir -o ttnn.mlir
 ```
 5. Use ttmlir-translate tool in compiler to generate the flatbuffer executable. See the [ttmlir-translate](./ttmlir-translate.md) documentation for more information on how to generate flatbuffer files.
diff --git a/python/test_infra/ttir_builder.py b/python/test_infra/ttir_builder.py
index 6314137aa..a6a302d4a 100644
--- a/python/test_infra/ttir_builder.py
+++ b/python/test_infra/ttir_builder.py
@@ -304,9 +304,7 @@ def eltwise_proxy(
             inputs,
             [output],
             self._get_operand_constraint_attr(3),
-            loc=Location.file(
-                filename=stack[0].filename, line=stack[0].lineno, col=id
-            ),
+            loc=Location.name(str(id)),
         )
 
         goldens = []
diff --git a/runtime/include/tt/runtime/detail/debug.h b/runtime/include/tt/runtime/detail/debug.h
index c5d84c4d9..ed049df2d 100644
--- a/runtime/include/tt/runtime/detail/debug.h
+++ b/runtime/include/tt/runtime/detail/debug.h
@@ -5,8 +5,12 @@
 #ifndef TT_RUNTIME_DETAIL_DEBUG_H
 #define TT_RUNTIME_DETAIL_DEBUG_H
 
+#include <functional>
+#include <optional>
 #include <ostream>
 
+#include "tt/runtime/types.h"
+
 namespace tt::runtime::debug {
 
 struct Env {
@@ -41,6 +45,46 @@ inline std::ostream &operator<<(std::ostream &os, Env const &env) {
   return os;
 }
 
+struct Hooks {
+#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
+  static Hooks const &
+  get(std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
+          operatorCallback = std::nullopt);
+#else
+  constexpr static Hooks get() { return Hooks(); }
+#endif
+
+  std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
+  getOperatorCallback() const {
+#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
+    return operatorCallback;
+#else
+    return std::nullopt;
+#endif
+  }
+
+private:
+#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
+  Hooks(std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
+            operatorCallback)
+      : operatorCallback(operatorCallback) {}
+
+  std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
+      operatorCallback;
+#else
+  constexpr Hooks() = default;
+#endif
+};
+
+inline std::ostream &operator<<(std::ostream &os, Hooks const &hooks) {
+  os << "debug::Hooks{\n"
+     << "\t"
+     << "operatorCallback: "
+     << static_cast<bool>(hooks.getOperatorCallback()) << ",\n"
+     << "}";
+  return os;
+}
+
 } // namespace tt::runtime::debug
 
 #endif // TT_RUNTIME_DETAIL_DEBUG_H
diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h
index 01bb9c86e..dad9afe09 100644
--- a/runtime/include/tt/runtime/detail/ttmetal.h
+++ b/runtime/include/tt/runtime/detail/ttmetal.h
@@ -79,6 +79,13 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex,
 
 void wait(Event event);
 
+std::string getOpDebugString(OpContext opContextHandle);
+
+Tensor getOpOutputTensor(OpContext opContextHandle,
+                         CallbackContext programContextHandle);
+
+std::vector<float> getTensorData(Tensor tensor);
+
 using InputBuffer =
     std::tuple<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>,
                std::shared_ptr<::tt::tt_metal::Event>>;
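For reference while reading the rest of this patch: the `Hooks` singleton above surfaces in Python as `ttrt.runtime.DebugHooks`, bound near the end of the diff. A minimal sketch of registering a per-op callback, assuming a runtime built with `TT_RUNTIME_DEBUG=1`:

```python
# Sketch: install a per-op debug hook (APIs added by this patch).
import ttrt.runtime

def print_op(binary, program_context, op_context):
    # Invoked after each executed op; the debug string carries the MLIR loc.
    print(ttrt.runtime.get_op_debug_str(op_context))

# Hooks is a process-wide singleton, so the callback is installed once.
ttrt.runtime.DebugHooks.get(print_op)
```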
diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h
index c027d0158..627b7ee56 100644
--- a/runtime/include/tt/runtime/detail/ttnn.h
+++ b/runtime/include/tt/runtime/detail/ttnn.h
@@ -120,8 +120,15 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex,
 
 void wait(Event event);
 
-void runProgram(::ttnn::MeshDevice &meshDevice,
-                ::tt::target::ttnn::Program const *program,
+std::string getOpDebugString(OpContext opContextHandle);
+
+Tensor getOpOutputTensor(OpContext opContextHandle,
+                         CallbackContext programContextHandle);
+
+std::vector<float> getTensorData(Tensor tensor);
+
+void runProgram(::ttnn::MeshDevice &meshDevice, Binary &executableHandle,
+                std::uint32_t programIndex,
                 std::vector<::ttnn::Tensor *> const &inputs,
                 std::vector<::ttnn::Tensor *> const &outputs);
diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h
index a070f2f0f..1dc721f66 100644
--- a/runtime/include/tt/runtime/runtime.h
+++ b/runtime/include/tt/runtime/runtime.h
@@ -69,6 +69,13 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex,
 
 void wait(Event event);
 
+std::string getOpDebugString(OpContext opContextHandle);
+
+Tensor getOpOutputTensor(OpContext opContextHandle,
+                         CallbackContext programContextHandle);
+
+std::vector<float> getTensorData(Tensor tensor);
+
 } // namespace tt::runtime
 
 #endif
diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h
index 330fb9196..8fd641195 100644
--- a/runtime/include/tt/runtime/types.h
+++ b/runtime/include/tt/runtime/types.h
@@ -12,6 +12,7 @@
 
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wcovered-switch-default"
+#include "ttmlir/Target/Common/debug_info_generated.h"
 #include "ttmlir/Target/Common/system_desc_generated.h"
 #include "ttmlir/Target/Common/types_generated.h"
 #pragma clang diagnostic pop
@@ -108,6 +109,7 @@ struct Binary : public Flatbuffer {
 
   std::vector<TensorDesc> getProgramInputs(std::uint32_t programIndex) const;
   std::vector<TensorDesc> getProgramOutputs(std::uint32_t programIndex) const;
+  const ::tt::target::GoldenTensor *getDebugInfoGolden(std::string &loc) const;
 };
 
 struct Device : public detail::RuntimeCheckedObjectImpl {
@@ -120,11 +122,20 @@ struct Event : public detail::RuntimeCheckedObjectImpl {
 
 struct Tensor : public detail::RuntimeCheckedObjectImpl {
   std::shared_ptr<void> data;
+
   Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data,
          DeviceRuntime runtime)
       : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data) {}
 };
 
+struct CallbackContext : public detail::RuntimeCheckedObjectImpl {
+  using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl;
+};
+
+struct OpContext : public detail::RuntimeCheckedObjectImpl {
+  using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl;
+};
+
 } // namespace tt::runtime
 
 #endif
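The `getDebugInfoGolden` accessor added to `Binary` above is also exported to Python (see the `ttrt.binary` pybind changes below). A minimal sketch of looking up a recorded golden by its loc key, assuming a flatbuffer compiled with golden info; the path `out.ttnn` is a placeholder:

```python
# Sketch: fetch a golden tensor recorded in the flatbuffer's debug_info.
import ttrt.binary

fbb = ttrt.binary.load_binary_from_path("out.ttnn")  # placeholder path
# Returns the golden's data as a flat list of floats; empty if no golden
# was recorded under that key.
data = fbb.get_debug_info_golden("input_0")
print(f"golden has {len(data)} elements")
```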
diff --git a/runtime/lib/binary.cpp b/runtime/lib/binary.cpp
index 9e586f8fa..92be39d27 100644
--- a/runtime/lib/binary.cpp
+++ b/runtime/lib/binary.cpp
@@ -103,6 +103,22 @@ std::vector<TensorDesc> getProgramOutputs(Flatbuffer binary,
   return outputs;
 }
 
+const ::tt::target::GoldenTensor *getDebugInfoGolden(Flatbuffer binary,
+                                                     std::string &loc) {
+  auto const *programs = getBinary(binary)->programs();
+  for (auto const *program : *programs) {
+    for (const ::tt::target::GoldenKV *goldenKV :
+         *program->debug_info()->golden_info()->golden_map()) {
+      if (std::string(goldenKV->key()->c_str()) == loc) {
+        return goldenKV->value();
+      }
+    }
+  }
+
+  LOG_WARNING("Golden information not found");
+  return nullptr;
+}
+
 } // namespace ttnn
 
 namespace metal {
@@ -177,6 +194,12 @@ std::vector<TensorDesc> getProgramOutputs(Flatbuffer binary,
   return outputs;
 }
 
+const ::tt::target::GoldenTensor *getDebugInfoGolden(Flatbuffer binary,
+                                                     std::string &loc) {
+  LOG_WARNING("Debug golden information not enabled for metal yet!");
+  return nullptr;
+}
+
 } // namespace metal
 
 namespace system_desc {
@@ -344,4 +367,20 @@ Binary::getProgramOutputs(std::uint32_t programIndex) const {
   throw std::runtime_error("Unsupported binary format");
 }
 
+const ::tt::target::GoldenTensor *
+Binary::getDebugInfoGolden(std::string &loc) const {
+  if (::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier(
+          handle.get())) {
+    return ttnn::getDebugInfoGolden(*this, loc);
+  }
+
+  if (::tt::target::metal::SizePrefixedTTMetalBinaryBufferHasIdentifier(
+          handle.get())) {
+    return metal::getDebugInfoGolden(*this, loc);
+  }
+
+  throw std::runtime_error(
+      "Unsupported binary format for obtaining golden information");
+}
+
 } // namespace tt::runtime
diff --git a/runtime/lib/common/debug.cpp b/runtime/lib/common/debug.cpp
index f07517764..34a274ca8 100644
--- a/runtime/lib/common/debug.cpp
+++ b/runtime/lib/common/debug.cpp
@@ -13,6 +13,17 @@ Env const &Env::get(bool loadKernelsFromDisk, bool enableAsyncTTNN) {
   return config;
 }
 
+#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
+Hooks const &Hooks::get(
+    std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
+        operatorCallback) {
+  static Hooks config(operatorCallback);
+  return config;
+}
+#else
+Hooks get() { return Hooks(); }
+#endif
+
 } // namespace tt::runtime::debug
 
 #endif
diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp
index 8b0e79daa..586b8394e 100644
--- a/runtime/lib/runtime.cpp
+++ b/runtime/lib/runtime.cpp
@@ -10,6 +10,7 @@
 
 #if defined(TT_RUNTIME_ENABLE_TTNN)
 #include "tt/runtime/detail/ttnn.h"
+#include "tt/runtime/ttnn/types.h"
 #endif
 
 #if defined(TT_RUNTIME_ENABLE_TTMETAL)
@@ -245,4 +246,53 @@ void wait(Event event) {
   throw std::runtime_error("runtime is not enabled");
 }
 
+std::string getOpDebugString(OpContext opContextHandle) {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getOpDebugString(opContextHandle);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getOpDebugString(opContextHandle);
+  }
+#endif
+  throw std::runtime_error("runtime is not enabled");
+}
+
+Tensor getOpOutputTensor(OpContext opContextHandle,
+                         CallbackContext programContextHandle) {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getOpOutputTensor(opContextHandle,
+                                                  programContextHandle);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getOpOutputTensor(opContextHandle,
+                                                     programContextHandle);
+  }
+#endif
+  throw std::runtime_error("runtime is not enabled");
+}
+
+std::vector<float> getTensorData(Tensor tensor) {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getTensorData(tensor);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getTensorData(tensor);
+  }
+#endif
+
+  throw std::runtime_error("runtime is not enabled");
+}
+
 } // namespace tt::runtime
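The three entry points above dispatch on the active runtime and back the Python bindings `get_op_debug_str` and `get_op_output_tensor`. A hedged sketch of pairing them inside a hook; note the TTMetal backend currently returns stubs (`""` and `{}`), so only TTNN produces useful data:

```python
# Sketch: read an op's output through the dispatching runtime API.
import ttrt.runtime

def on_op(binary, program_context, op_context):
    debug_str = ttrt.runtime.get_op_debug_str(op_context)
    # Returns the output as a flat list of floats; empty for ops with no
    # retrievable output (e.g. DeallocateOp) or on the TTMetal stub path.
    data = ttrt.runtime.get_op_output_tensor(op_context, program_context)
    if len(data) == 0:
        return
    print(f"{debug_str}: {len(data)} elements, first={data[0]}")

ttrt.runtime.DebugHooks.get(on_op)
```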
diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp
index 931349429..ab343554e 100644
--- a/runtime/lib/ttmetal/runtime.cpp
+++ b/runtime/lib/ttmetal/runtime.cpp
@@ -29,6 +29,10 @@ static ::tt::target::metal::TTMetalBinary const *getBinary(Flatbuffer binary) {
   return ::tt::target::metal::GetSizePrefixedTTMetalBinary(binary.handle.get());
 }
 
+static Tensor createNullTensor() {
+  return Tensor(nullptr, nullptr, DeviceRuntime::TTMetal);
+}
+
 Tensor createTensor(std::shared_ptr<void> data,
                     std::vector<std::uint32_t> const &shape,
                     std::vector<std::uint32_t> const &stride,
@@ -252,4 +256,23 @@ void wait(Event event) {
   }
 }
 
+std::string getOpDebugString(OpContext opContextHandle) {
+  // Not implemented
+  LOG_WARNING("obtaining op debug string for metal runtime not implemented");
+  return "";
+}
+
+Tensor getOpOutputTensor(OpContext opContextHandle,
+                         CallbackContext programContextHandle) {
+  // Not implemented
+  LOG_WARNING("obtaining op output tensor for metal runtime not implemented");
+  return createNullTensor();
+}
+
+std::vector<float> getTensorData(Tensor tensor) {
+  // Not implemented
+  LOG_WARNING("obtaining tensor data for metal runtime not implemented");
+  return {};
+}
+
 } // namespace tt::runtime::ttmetal
diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h
index fb985744d..5cd08c7ed 100644
--- a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h
+++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h
@@ -46,6 +46,11 @@ class ProgramTensorPool {
     return *liveTensors.at(globalId);
   }
 
+  const ::ttnn::Tensor &at(std::uint32_t globalId) const {
+    assert(liveTensors.contains(globalId));
+    return *liveTensors.at(globalId);
+  }
+
   size_t erase(std::uint32_t globalId) {
     assert(liveTensors.contains(globalId) &&
            intermedTensors.contains(globalId));
@@ -161,6 +166,7 @@ class ProgramContext {
   // Tensor Pool Operations
   //
   ProgramTensorPool &getTensorPool() { return tensorPool; }
+  const ProgramTensorPool &getTensorPool() const { return tensorPool; }
 
 private:
   ProgramTensorPool tensorPool;
diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp
index 996f72552..8cfa01389 100644
--- a/runtime/lib/ttnn/program.cpp
+++ b/runtime/lib/ttnn/program.cpp
@@ -26,33 +26,62 @@
 #include "operations/normalization/softmax.h"
 #include "operations/pool/maxpool2d.h"
 #include "operations/reduction/reduction.h"
+#include "tt/runtime/detail/debug.h"
 #include "tt/runtime/detail/logger.h"
 #include "tt/runtime/ttnn/types.h"
+#include "tt/runtime/utils.h"
 #include "ttmlir/Target/TTNN/program_generated.h"
 
 namespace tt::runtime::ttnn {
 
 using LogType = ::tt::runtime::logger::LogType;
 
+static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) {
+  bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier(
+      binary.handle.get());
+  if (not isTTNN) {
+    throw std::runtime_error("Unsupported binary format");
+  }
+  return ::tt::target::ttnn::GetSizePrefixedTTNNBinary(binary.handle.get());
+}
+
 class ProgramExecutor {
 public:
-  ProgramExecutor(const TensorMap &liveTensors,
+  ProgramExecutor(Binary &executableHandle, const TensorMap &liveTensors,
                   const std::unordered_set<uint32_t> &programInputs,
                   const std::unordered_set<uint32_t> &programOutputs,
                   ::ttnn::MeshDevice *meshDevice)
-      : context(ProgramContext(liveTensors, programInputs, programOutputs,
+      : executableHandle(executableHandle),
+        context(ProgramContext(liveTensors, programInputs, programOutputs,
                                meshDevice)) {}
 
+  void runCallback(Binary &executableHandle,
+                   const ::tt::target::ttnn::Operation *opContext,
+                   ProgramContext *programContext) {
+    if (auto callback = debug::Hooks::get().getOperatorCallback(); callback) {
+      std::shared_ptr<void> programContextPtr =
+          ::tt::runtime::utils::unsafe_borrow_shared(programContext);
+      std::shared_ptr<void> opContextPtr =
+          ::tt::runtime::utils::unsafe_borrow_shared(
+              const_cast<::tt::target::ttnn::Operation *>(opContext));
+      (*callback)(executableHandle,
+                  CallbackContext(programContextPtr, DeviceRuntime::TTNN),
+                  OpContext(opContextPtr, DeviceRuntime::TTNN));
+    }
+  }
+
   void execute(const ::tt::target::ttnn::Program *program) {
     for (const ::tt::target::ttnn::Operation *op : *program->operations()) {
       LOG_DEBUG(LogType::LogRuntimeTTNN,
                 "Executing operation: ", op->debug_info()->c_str());
       runOperation(op);
+      runCallback(executableHandle, op, &context);
     }
   }
 
   ProgramContext &getContext() { return context; }
 
 private:
+  Binary executableHandle;
   ProgramContext context;
   void runOperation(const ::tt::target::ttnn::Operation *op);
   void runEltwiseOperation(const ::tt::target::ttnn::EltwiseOp *op);
@@ -117,8 +146,7 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) {
     return operations::creation::run(op->type_as_FullOp(), context);
   }
   case ::tt::target::ttnn::OpType::EltwiseOp: {
-    const ::tt::target::ttnn::EltwiseOp *eltwiseOp = op->type_as_EltwiseOp();
-    return runEltwiseOperation(eltwiseOp);
+    return runEltwiseOperation(op->type_as_EltwiseOp());
  }
   // ANCHOR: adding_an_op_matmul_runtime_program
   case ::tt::target::ttnn::OpType::MatmulOp: {
@@ -183,10 +211,13 @@ static bool handleNopProgram(::tt::target::ttnn::Program const *program,
   return isNop;
 }
 
-void runProgram(::ttnn::MeshDevice &meshDevice,
-                ::tt::target::ttnn::Program const *program,
+void runProgram(::ttnn::MeshDevice &meshDevice, Binary &executableHandle,
+                std::uint32_t programIndex,
                 std::vector<::ttnn::Tensor *> const &inputs,
                 std::vector<::ttnn::Tensor *> const &outputs) {
+  ::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle);
+  ::tt::target::ttnn::Program const *program =
+      fbb.programs()->Get(programIndex);
   if (handleNopProgram(program, inputs, outputs)) {
     return;
   }
@@ -212,8 +243,8 @@ void runProgram(::ttnn::MeshDevice &meshDevice,
     LOG_ASSERT(inserted, "Duplicate output tensor");
     programOutputs.emplace(output->global_id());
   }
-  ProgramExecutor executor(liveTensors, programInputs, programOutputs,
-                           &meshDevice);
+  ProgramExecutor executor(executableHandle, liveTensors, programInputs,
+                           programOutputs, &meshDevice);
   executor.execute(program);
 }
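One lifetime caveat implied by `unsafe_borrow_shared` above: the context handles passed to a hook wrap non-owning pointers, so they are only valid while the callback runs. A sketch of the safe usage pattern, copying data out instead of stashing the handles:

```python
# Sketch: extract what you need inside the callback; do NOT keep
# op_context/program_context beyond the call (they borrow runtime memory).
import ttrt.runtime

trace = []

def on_op(binary, program_context, op_context):
    trace.append(ttrt.runtime.get_op_debug_str(op_context))  # copy out

ttrt.runtime.DebugHooks.get(on_op)
# After a submit() completes, `trace` lists the ops in execution order.
```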
diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp
index 9811da28d..86fd2d25c 100644
--- a/runtime/lib/ttnn/runtime.cpp
+++ b/runtime/lib/ttnn/runtime.cpp
@@ -5,6 +5,7 @@
 #include "tt/runtime/detail/debug.h"
 #include "tt/runtime/detail/logger.h"
 #include "tt/runtime/detail/ttnn.h"
+#include "tt/runtime/ttnn/types.h"
 #include "tt/runtime/ttnn/utils.h"
 #include "tt/runtime/utils.h"
 #include "ttmlir/Target/TTNN/Target.h"
@@ -71,6 +72,10 @@ createOwnedTensor(std::shared_ptr<void> data,
                   ::ttnn::Layout::ROW_MAJOR);
 }
 
+static Tensor createNullTensor() {
+  return Tensor(nullptr, nullptr, DeviceRuntime::TTNN);
+}
+
 Tensor createTensor(std::shared_ptr<void> data,
                     std::vector<std::uint32_t> const &shape,
                     std::vector<std::uint32_t> const &stride,
@@ -161,33 +166,27 @@ void deallocateBuffers(Device deviceHandle) {
   }
 }
 
-static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) {
-  bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier(
-      binary.handle.get());
-  LOG_ASSERT(isTTNN, "Unsupported binary format");
-  return ::tt::target::ttnn::GetSizePrefixedTTNNBinary(binary.handle.get());
-}
-
 Event submit(Device deviceHandle, Binary executableHandle,
              std::uint32_t programIndex,
              std::vector<Tensor> const &inputHandles,
             std::vector<Tensor> const &outputHandles) {
   ::ttnn::MeshDevice &meshDevice =
       deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN);
-  ::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle);
   std::vector<::ttnn::Tensor *> inputs;
   inputs.reserve(inputHandles.size());
   for (auto &input : inputHandles) {
     LOG_ASSERT(input.matchesRuntime(DeviceRuntime::TTNN));
     inputs.push_back(static_cast<::ttnn::Tensor *>(input.handle.get()));
   }
+
   std::vector<::ttnn::Tensor *> outputs;
   outputs.reserve(outputHandles.size());
   for (auto &output : outputHandles) {
     LOG_ASSERT(output.matchesRuntime(DeviceRuntime::TTNN));
     outputs.push_back(static_cast<::ttnn::Tensor *>(output.handle.get()));
   }
+
-  tt::runtime::ttnn::runProgram(meshDevice, fbb.programs()->Get(programIndex),
+  tt::runtime::ttnn::runProgram(meshDevice, executableHandle, programIndex,
                                 inputs, outputs);
   return Event(nullptr, DeviceRuntime::TTNN);
 }
@@ -197,4 +196,149 @@ void wait(Event event) {
   LOG_ASSERT(event.matchesRuntime(DeviceRuntime::TTNN));
 }
 
+std::string getOpDebugString(OpContext opContextHandle) {
+  auto const &opContext =
+      opContextHandle.as<::tt::target::ttnn::Operation>(DeviceRuntime::TTNN);
+  return std::string(opContext.debug_info()->c_str());
+}
+
+Tensor getOpOutputTensor(OpContext opContextHandle,
+                         CallbackContext programContextHandle) {
+  auto const &programContext =
+      programContextHandle.as<tt::runtime::ttnn::ProgramContext>(
+          DeviceRuntime::TTNN);
+  auto const &opContext =
+      opContextHandle.as<::tt::target::ttnn::Operation>(DeviceRuntime::TTNN);
+  const ttnn::ProgramTensorPool &tensorPool = programContext.getTensorPool();
+  std::int32_t globalId{-1};
+  const ::ttnn::Tensor *outPtr = nullptr;
+
+  switch (opContext.type_type()) {
+  case ::tt::target::ttnn::OpType::GetDeviceOp: {
+    globalId = opContext.type_as_GetDeviceOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::ToMemoryConfigOp: {
+    globalId = opContext.type_as_ToMemoryConfigOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::ToLayoutOp: {
+    globalId = opContext.type_as_ToLayoutOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::TypecastOp: {
+    globalId = opContext.type_as_TypecastOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::ToDeviceOp: {
+    globalId = opContext.type_as_ToDeviceOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::FromDeviceOp: {
+    globalId = opContext.type_as_FromDeviceOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::EmptyOp: {
+    globalId = opContext.type_as_EmptyOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::FullOp: {
+    globalId = opContext.type_as_FullOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::EltwiseOp: {
+    globalId = opContext.type_as_EltwiseOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::MatmulOp: {
+    globalId = opContext.type_as_MatmulOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::ReductionOp: {
+    globalId = opContext.type_as_ReductionOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::EmbeddingOp: {
+    globalId = opContext.type_as_EmbeddingOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::SoftmaxOp: {
+    globalId = opContext.type_as_SoftmaxOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::TransposeOp: {
+    globalId = opContext.type_as_TransposeOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::ConcatOp: {
+    globalId = opContext.type_as_ConcatOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::ReshapeOp: {
+    globalId = opContext.type_as_ReshapeOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::SliceOp: {
+    globalId = opContext.type_as_SliceOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::Conv2dOp: {
+    globalId = opContext.type_as_Conv2dOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::MaxPool2dOp: {
+    globalId = opContext.type_as_MaxPool2dOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::AllGatherOp: {
+    globalId = opContext.type_as_AllGatherOp()->out()->global_id();
+    break;
+  }
+  case ::tt::target::ttnn::OpType::DeallocateOp: {
+    LOG_WARNING("getting output tensor for DeallocateOp is not supported");
+    return createNullTensor();
+  }
+  default: {
+    throw std::runtime_error("Unsupported operation type");
+  }
+  }
+
+  if (tensorPool.contains(globalId)) {
+    outPtr = &tensorPool.at(globalId);
+  } else {
+    LOG_WARNING("Output tensor not found in tensor pool");
+    return createNullTensor();
+  }
+
+  ::ttnn::Tensor hostTensor = ::ttnn::from_device(*outPtr);
+  ::ttnn::Tensor outCopy =
+      ::ttnn::to_layout(hostTensor, ::ttnn::ROW_MAJOR_LAYOUT, std::nullopt,
+                        std::nullopt, static_cast<::ttnn::Device *>(nullptr));
+
+  void *src = ::tt::tt_metal::get_raw_host_data_ptr(outCopy);
+  std::uint32_t outCopySize = outCopy.volume() * outCopy.element_size();
+  std::shared_ptr<void> data = ::tt::runtime::utils::malloc_shared(outCopySize);
+  std::memcpy(data.get(), src, outCopySize);
+
+  auto tensor = std::make_shared<::ttnn::Tensor>(
+      ttnn::createStorage(data.get(), outCopy.volume(),
+                          ::tt::target::DataType::Float32),
+      outCopy.shape().value, ::ttnn::DataType::FLOAT32,
+      ::ttnn::Layout::ROW_MAJOR);
+
+  return Tensor(std::static_pointer_cast<void>(tensor), data,
+                DeviceRuntime::TTNN);
+}
+
+std::vector<float> getTensorData(Tensor tensor) {
+  ::ttnn::Tensor *nnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
+  if (nnTensor == nullptr) {
+    return {};
+  }
+
+  void *dataPtr = ::tt::tt_metal::get_raw_host_data_ptr(*nnTensor);
+  return std::vector<float>(static_cast<float *>(dataPtr),
+                            static_cast<float *>(dataPtr) +
+                                nnTensor->volume());
+}
+
 } // namespace tt::runtime::ttnn
diff --git a/runtime/tools/python/ttrt/binary/module.cpp b/runtime/tools/python/ttrt/binary/module.cpp
index d8f1d004c..ca1f6ad8f 100644
--- a/runtime/tools/python/ttrt/binary/module.cpp
+++ b/runtime/tools/python/ttrt/binary/module.cpp
@@ -2,9 +2,13 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
+#include <cstring>
+
 #include "tt/runtime/types.h"
 
+#include <numeric>
 #include <pybind11/pybind11.h>
+#include <vector>
 
 namespace py = pybind11;
@@ -27,7 +31,23 @@ PYBIND11_MODULE(_C, m) {
       .def_property_readonly("file_identifier",
                              &tt::runtime::Binary::getFileIdentifier)
       .def("as_json", &tt::runtime::Binary::asJson)
-      .def("store", &tt::runtime::Binary::store);
+      .def("store", &tt::runtime::Binary::store)
+      .def("get_debug_info_golden", [](tt::runtime::Binary &binary,
+                                       std::string &loc) {
+        const ::tt::target::GoldenTensor *goldenTensor =
+            binary.getDebugInfoGolden(loc);
+        if (goldenTensor == nullptr) {
+          return std::vector<float>();
+        }
+
+        int totalDataSize = std::accumulate((*goldenTensor->shape()).begin(),
+                                            (*goldenTensor->shape()).end(), 1,
+                                            std::multiplies<int>());
+        std::vector<float> dataVec(totalDataSize);
+        std::memcpy(dataVec.data(), goldenTensor->data()->data(),
+                    totalDataSize * sizeof(float));
+        return dataVec;
+      });
   py::class_<tt::runtime::SystemDesc>(m, "SystemDesc")
       .def_property_readonly("version", &tt::runtime::SystemDesc::getVersion)
      .def_property_readonly("ttmlir_git_hash",
diff --git a/runtime/tools/python/ttrt/common/golden.py b/runtime/tools/python/ttrt/common/golden.py
new file mode 100644
index 000000000..59adc97a1
--- /dev/null
+++ b/runtime/tools/python/ttrt/common/golden.py
@@ -0,0 +1,158 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import json
+import importlib.machinery
+import sys
+import signal
+import os
+import io
+import subprocess
+import time
+import socket
+from pkg_resources import get_distribution
+import shutil
+import atexit
+import re
+
+from ttrt.common.util import *
+
+
+def get_atol_rtol_pcc(golden, calculated):
+    import numpy as np
+    import torch
+
+    # Calculate atol and rtol
+    cal_atol = torch.max(torch.abs(golden - calculated)).item()
+    cal_rtol = torch.max(torch.abs(golden - calculated) / torch.abs(calculated)).item()
+
+    # Calculate PCC
+    def get_pcc(golden, calculated):
+        # Both tensors are nan
+        if torch.all(torch.isnan(golden)) and torch.all(torch.isnan(calculated)):
+            print("Both tensors are 'nan'")
+            return 1.0
+        # Test if either is completely zero
+        elif torch.any(golden.bool()) != torch.any(calculated.bool()):
+            return 0.0
+        # One tensor is all nan, the other is not
+        elif torch.all(torch.isnan(golden)) or torch.all(torch.isnan(calculated)):
+            print("One tensor is all nan, the other is not.")
+            return 0.0
+        else:
+            # For now, mask all infs and nans so that we check the rest... TODO
+            golden = golden.clone()
+            golden[
+                torch.logical_or(
+                    torch.isnan(golden),
+                    torch.logical_or(torch.isinf(golden), torch.isneginf(golden)),
+                )
+            ] = 0
+            calculated = calculated.clone()
+            calculated[
+                torch.logical_or(
+                    torch.isnan(calculated),
+                    torch.logical_or(
+                        torch.isinf(calculated), torch.isneginf(calculated)
+                    ),
+                )
+            ] = 0
+
+            if torch.equal(golden, calculated):
+                return 1.0
+
+            if golden.dtype == torch.bfloat16:
+                golden = golden.type(torch.float32)
+                calculated = calculated.type(torch.float32)
+
+            # Single element case
+            if golden.numel() == 1:
+                return float(torch.equal(golden, calculated))
+
+            # If both tensors are constant
+            if torch.max(golden) == torch.min(golden) and torch.max(
+                calculated
+            ) == torch.min(calculated):
+                return torch.isclose(torch.max(golden), torch.max(calculated)).item()
+
+            cal_pcc = np.ma.corrcoef(
+                np.ma.masked_invalid(torch.squeeze(golden).detach().numpy()).flatten(),
+                np.ma.masked_invalid(
+                    torch.squeeze(calculated).detach().numpy()
+                ).flatten(),
+            )
+            # Remove correlation coefficient with self (typically always 1.0)
+            mask = np.ones(cal_pcc.shape, dtype=bool)
+            np.fill_diagonal(mask, 0)
+            cal_pcc = np.min(cal_pcc[mask])
+
+            if isinstance(cal_pcc, np.ma.core.MaskedConstant):
+                return 1.0
+
+            return cal_pcc
+
+    cal_pcc = get_pcc(golden, calculated)
+
+    return (
+        cal_atol,
+        cal_rtol,
+        cal_pcc,
+        f"Max ATOL Delta: {cal_atol}, Max RTOL Delta: {cal_rtol}, PCC: {cal_pcc}",
+    )
+
+
+def golden(binary, programContext, opContext):
+    import torch
+    import ttrt.runtime
+    import ttrt.binary
+
+    print("-----------executing golden comparison-----------")
+
+    try:
+        op_debug_str = ttrt.runtime.get_op_debug_str(opContext)
+
+        # find matching golden tensor based on loc in op debug string
+        match = re.search(r"loc\(([^)]+)\)", op_debug_str)
+
+        if not match:
+            print(f"debug_str={op_debug_str}")
+            print("No location found in debug string - skipping golden comparison")
+            return
+
+        loc = match.group(1).replace('"', "")
+        print(f"found location={loc}")
+
+        op_golden_tensor = binary.get_debug_info_golden(loc)
+        op_output_tensor = ttrt.runtime.get_op_output_tensor(opContext, programContext)
+
+        if len(op_golden_tensor) == 0:
+            print("Golden tensor is empty - skipping golden comparison")
+            return
+
+        if len(op_output_tensor) == 0:
+            print("Output tensor is empty - skipping golden comparison")
+            return
+
+        if len(op_golden_tensor) != len(op_output_tensor):
+            print(
+                "Golden and output tensor sizes do not match - skipping golden comparison"
+            )
+            return
+
+        golden_tensor_torch = torch.tensor(
+            op_golden_tensor, dtype=torch.float32
+        ).flatten()
+        output_tensor_torch = torch.tensor(
+            op_output_tensor, dtype=torch.float32
+        ).flatten()
+
+        _, _, cal_pcc, output_str = get_atol_rtol_pcc(
+            golden_tensor_torch, output_tensor_torch
+        )
+
+        print(f"PCC={cal_pcc}")
+        print(output_str)
+    finally:
+        print("-----------finished executing golden comparison-----------")
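`get_atol_rtol_pcc` also works standalone on host tensors, which is handy for sanity-checking thresholds before wiring it into the hook. A small runnable example:

```python
# Sketch: PCC close to 1.0 means the calculated tensor tracks the golden.
import torch
from ttrt.common.golden import get_atol_rtol_pcc

golden = torch.tensor([1.0, 2.0, 3.0, 4.0])
calculated = torch.tensor([1.0, 2.0, 3.0, 4.001])
atol, rtol, pcc, msg = get_atol_rtol_pcc(golden, calculated)
print(msg)  # e.g. "Max ATOL Delta: 0.001..., Max RTOL Delta: ..., PCC: 0.999..."
```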
diff --git a/runtime/tools/python/ttrt/common/run.py b/runtime/tools/python/ttrt/common/run.py
index 24796717a..c2ae10ac9 100644
--- a/runtime/tools/python/ttrt/common/run.py
+++ b/runtime/tools/python/ttrt/common/run.py
@@ -18,6 +18,7 @@
 
 from ttrt.common.util import *
 from ttrt.common.query import Query
+from ttrt.common.golden import golden
 
 
 class Run:
@@ -165,6 +166,13 @@ def initialize_api():
             choices=None,
             help="test file to save results to",
         )
+        Run.register_arg(
+            name="--golden",
+            type=bool,
+            default=True,
+            choices=[True, False],
+            help="run golden comparison for intermediate and output tensors",
+        )
         Run.register_arg(
             name="binary",
             type=str,
@@ -354,6 +362,9 @@ def _execute(binaries):
                 self.logging.warning(f"no binaries found to run - returning early")
                 return
 
+            if self["--golden"]:
+                callback_env = ttrt.runtime.DebugHooks.get(golden)
+
             debug_env = ttrt.runtime.DebugEnv.get(
                 self["--load-kernels-from-disk"], self["--enable-async-ttnn"]
             )
@@ -389,8 +400,21 @@ def _execute(binaries):
                     )
 
                     program = bin.get_program(program_index)
+                    golden_inputs = []
+
+                    for i in range(len(program.program["inputs"])):
+                        golden_tensor = bin.fbb.get_debug_info_golden(
+                            f"input_{i}"
+                        )
+
+                        if len(golden_tensor) != 0:
+                            golden_inputs.append(
+                                torch.tensor(golden_tensor, dtype=torch.float32)
+                            )
+
                     program.populate_inputs(
-                        Run.TorchInitializer.get_initilizer(self["--init"])
+                        Run.TorchInitializer.get_initilizer(self["--init"]),
+                        golden_inputs,
                     )
                     program.populate_outputs(
                         Run.TorchInitializer.get_initilizer("zeros")
diff --git a/runtime/tools/python/ttrt/common/util.py b/runtime/tools/python/ttrt/common/util.py
index 8f38c6bc3..370643e7d 100644
--- a/runtime/tools/python/ttrt/common/util.py
+++ b/runtime/tools/python/ttrt/common/util.py
@@ -522,22 +522,6 @@ def get_ttsys_file_extension():
         return Flatbuffer.ttsys_file_extension
 
 
-class Golden:
-    def __init__(self, tensor_id, tensor_shape, tensor_stride, tensor_data):
-        self.tensor_id = tensor_id
-        self.tensor_shape = tensor_shape
-        self.tensor_stride = tensor_stride
-        self.tensor_data = tensor_data
-
-    def get_golden_tensor(self):
-        tensor_byte_data = bytes(self.tensor_data)
-        float_data = np.frombuffer(tensor_byte_data, dtype=np.float32)
-        golden_tensor = torch.tensor(float_data, dtype=torch.float32).reshape(
-            self.tensor_shape
-        )
-        return golden_tensor
-
-
 class Binary(Flatbuffer):
     def __init__(self, logger, file_manager, file_path, capsule=None):
         super().__init__(logger, file_manager, file_path, capsule=capsule)
@@ -557,21 +541,6 @@ def __init__(self, logger, file_manager, file_path, capsule=None):
             program = Binary.Program(i, self.fbb_dict["programs"][i])
             self.programs.append(program)
 
-        # populate golden tensors if they exist
-        if "debug_info" in self.fbb_dict["programs"][i]:
-            if "golden_info" in self.fbb_dict["programs"][i]["debug_info"]:
-                golden_info_list = self.fbb_dict["programs"][i]["debug_info"][
-                    "golden_info"
-                ]["golden_map"]
-
-                for golden_tensor_dict in golden_info_list:
-                    Golden(
-                        golden_tensor_dict["key"],
-                        golden_tensor_dict["value"]["shape"],
-                        golden_tensor_dict["value"]["stride"],
-                        golden_tensor_dict["value"]["data"],
-                    )
-
     def check_system_desc(self, query):
         import ttrt.binary
 
@@ -617,15 +586,20 @@ def __init__(self, index, program):
             self.input_tensors = []
             self.output_tensors = []
 
-        def populate_inputs(self, init_fn):
-            for i in self.program["inputs"]:
-                torch_tensor = init_fn(
-                    i["desc"]["shape"],
-                    dtype=Binary.Program.from_data_type(
-                        i["desc"]["layout"]["memory_desc"]["data_type"]
-                    ),
-                )
-                self.input_tensors.append(torch_tensor)
+        def populate_inputs(self, init_fn, golden_inputs=[]):
+            if len(golden_inputs) > 0:
+                assert len(golden_inputs) == len(self.program["inputs"])
+                for golden_input in golden_inputs:
+                    self.input_tensors.append(golden_input)
+            else:
+                for i in self.program["inputs"]:
+                    torch_tensor = init_fn(
+                        i["desc"]["shape"],
+                        dtype=Binary.Program.from_data_type(
+                            i["desc"]["layout"]["memory_desc"]["data_type"]
+                        ),
+                    )
+                    self.input_tensors.append(torch_tensor)
 
         def populate_outputs(self, init_fn):
             for i in self.program["outputs"]:
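The input-seeding precedence implemented by `populate_inputs` above, distilled into a self-contained sketch (`seed_inputs` is a hypothetical helper; the real code maps the flatbuffer dtype instead of assuming float32):

```python
# Sketch: goldens win over the --init initializer when the binary records
# them, so on-device results stay comparable to the recorded goldens.
import torch

def seed_inputs(descs, init_fn, golden_inputs=()):
    if golden_inputs:
        assert len(golden_inputs) == len(descs)
        return list(golden_inputs)
    return [init_fn(d["desc"]["shape"], dtype=torch.float32) for d in descs]

# e.g. seed_inputs(descs, lambda shape, dtype: torch.randn(shape, dtype=dtype))
```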
diff --git a/runtime/tools/python/ttrt/runtime/__init__.py b/runtime/tools/python/ttrt/runtime/__init__.py
index 1a616db24..642b0401f 100644
--- a/runtime/tools/python/ttrt/runtime/__init__.py
+++ b/runtime/tools/python/ttrt/runtime/__init__.py
@@ -10,6 +10,7 @@
         DataType,
         DeviceRuntime,
         DebugEnv,
+        DebugHooks,
         get_current_runtime,
         set_compatible_runtime,
         get_current_system_desc,
@@ -19,6 +20,8 @@
         create_tensor,
         create_multi_device_tensor,
         wait,
+        get_op_output_tensor,
+        get_op_debug_str,
         WorkaroundEnv,
     )
 except ModuleNotFoundError:
diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp
index 4f528c02f..dfc4a6820 100644
--- a/runtime/tools/python/ttrt/runtime/module.cpp
+++ b/runtime/tools/python/ttrt/runtime/module.cpp
@@ -9,6 +9,7 @@
 #include "tt/runtime/runtime.h"
 #include "tt/runtime/utils.h"
 
+#include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
@@ -21,6 +22,8 @@ PYBIND11_MODULE(_C, m) {
       .def("deallocate_buffers", &tt::runtime::detail::deallocateBuffers);
   py::class_<tt::runtime::Event>(m, "Event");
   py::class_<tt::runtime::Tensor>(m, "Tensor");
+  py::class_<tt::runtime::OpContext>(m, "OpContext");
+  py::class_<tt::runtime::CallbackContext>(m, "CallbackContext");
   py::enum_<::tt::target::DataType>(m, "DataType")
       .value("Float32", ::tt::target::DataType::Float32)
       .value("Float16", ::tt::target::DataType::Float16)
@@ -38,7 +41,6 @@ PYBIND11_MODULE(_C, m) {
       .value("Disabled", ::tt::runtime::DeviceRuntime::Disabled)
       .value("TTNN", ::tt::runtime::DeviceRuntime::TTNN)
       .value("TTMetal", ::tt::runtime::DeviceRuntime::TTMetal);
-
   m.def("get_current_runtime", &tt::runtime::getCurrentRuntime,
         "Get the backend device runtime type");
   m.def("get_available_runtimes", &tt::runtime::getAvailableRuntimes,
@@ -87,6 +89,17 @@ PYBIND11_MODULE(_C, m) {
         py::arg("executable"), py::arg("program_index"), py::arg("inputs"),
         py::arg("outputs"), "Submit a binary for execution");
   m.def("wait", &tt::runtime::wait, py::arg("event"));
+  m.def(
+      "get_op_output_tensor",
+      [](tt::runtime::OpContext &opContextHandle,
+         tt::runtime::CallbackContext &programContextHandle) {
+        tt::runtime::Tensor tensor = tt::runtime::getOpOutputTensor(
+            opContextHandle, programContextHandle);
+        return tt::runtime::getTensorData(tensor);
+      },
+      "Get the output tensor of the op");
+  m.def("get_op_debug_str", &tt::runtime::getOpDebugString,
+        "Get the debug string of the op");
 
   py::class_<tt::runtime::debug::Env>(m, "DebugEnv")
       .def_static("get", &tt::runtime::debug::Env::get)
       .def("__str__", [](const tt::runtime::debug::Env &env) {
@@ -96,6 +109,26 @@ PYBIND11_MODULE(_C, m) {
         return os.str();
       });
 
+  py::class_<tt::runtime::debug::Hooks>(m, "DebugHooks")
+      .def_static("get",
+                  [](py::function func) {
+#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
+                    tt::runtime::debug::Hooks::get(
+                        [func](tt::runtime::Binary binary,
+                               tt::runtime::CallbackContext programContext,
+                               tt::runtime::OpContext opContext) {
+                          func(binary, programContext, opContext);
+                        });
+#else
+                    tt::runtime::debug::Hooks::get();
+#endif
+                  })
+      .def("__str__", [](const tt::runtime::debug::Hooks &hooks) {
+        std::stringstream os;
+        os << hooks;
+        return os.str();
+      });
+
   py::class_<tt::runtime::workaround::Env>(m, "WorkaroundEnv")
       .def_static("get", &tt::runtime::workaround::Env::get)
       .def("__str__", [](const tt::runtime::workaround::Env &env) {