From 5f3200001213c476f8f17c92feb1b47ceca8a062 Mon Sep 17 00:00:00 2001 From: Kristijan Mitrovic Date: Mon, 9 Dec 2024 15:11:20 +0100 Subject: [PATCH] Fixed failing uplifts. (#107) Switched to new runtime API. Fixes #102. --- src/common/api_impl.cc | 41 +++++++++++++------------------------- third_party/CMakeLists.txt | 2 +- 2 files changed, 15 insertions(+), 28 deletions(-) diff --git a/src/common/api_impl.cc b/src/common/api_impl.cc index f4b66b7..2c8f817 100644 --- a/src/common/api_impl.cc +++ b/src/common/api_impl.cc @@ -361,7 +361,8 @@ tt_pjrt_status BufferInstance::CopyToHost(void *dst, size_t dst_size, }; DLOG_F(INFO, "Copy to host id: %d", unique_id()); - memcpy(dst, tensor().data.get(), dst_size); + tt::runtime::memcpy(dst, tensor()); + auto copy_done_event = new EventInstance(); copy_done_event->OnReady(copy_done_callback, nullptr); @@ -1006,45 +1007,29 @@ tt_pjrt_status LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { DLOG_F(LOG_DEBUG, "LoadedExecutableInstance::Execute"); - std::vector rt_inputs; - std::vector rt_outputs; - auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc(); int dev_0 = chip_ids[0]; auto device = tt::runtime::openDevice({dev_0}); - rt_inputs.reserve(args->num_args); assert(args->num_devices == 1); int dev_index = 0; tt::runtime::Binary binary(image_->binary); + + std::vector rt_inputs; + rt_inputs.reserve(args->num_args); + for (size_t i = 0; i < args->num_args; ++i) { auto *buffer = BufferInstance::Unwrap(args->argument_lists[dev_index][i]); rt_inputs.emplace_back(buffer->tensor()); DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id()); } - auto output_specs = binary.getProgramOutputs(0); - for (auto output_spec : output_specs) { - int volume = 1; - for (auto dim : output_spec.shape) { - volume *= dim; - } - int size = volume * output_spec.itemsize; - // TODO What to do with stride - void *data = malloc(size); - -#ifdef DEBUG - memset(data, 0, size); -#endif - std::shared_ptr data_ptr(const_cast(data), [](void *) {}); - tt::runtime::Tensor tensor = tt::runtime::createTensor( - data_ptr, output_spec.shape, output_spec.stride, output_spec.itemsize, - output_spec.dataType); - rt_outputs.emplace_back(tensor); - } - tt::runtime::Event event = - tt::runtime::submit(device, binary, 0, rt_inputs, rt_outputs); - (void)event; + std::vector rt_outputs = + tt::runtime::submit(device, binary, 0, rt_inputs); + std::vector output_specs = + binary.getProgramOutputs(0); + + assert(rt_outputs.size() == output_specs.size()); for (size_t i = 0; i < output_specs.size(); ++i) { auto result_buffer = std::make_unique( @@ -1055,9 +1040,11 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { DLOG_F(INFO, "Runtime output id: %d", result_buffer->unique_id()); args->output_lists[dev_index][i] = *(result_buffer.release()); } + if (args->device_complete_events) { args->device_complete_events[dev_index] = *(new EventInstance()); } + tt::runtime::closeDevice(device); return tt_pjrt_status::kSuccess; diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 326ebc9..026a7be 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # -set(TT_MLIR_VERSION "0640c7c8366b9b44fd6e210eb7ca76d1c2ec4121") +set(TT_MLIR_VERSION "223d24444c419dec906f87749f161f03b321fce7") set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1") if (TOOLCHAIN STREQUAL "ON")