Skip to content

Commit

Permalink
Fixed failing uplifts. (#107)
Browse files (browse the repository at this point in the history)
Switched to new runtime API.
Fixes #102.
  • Loading branch information
kmitrovicTT authored Dec 9, 2024
1 parent 82030b8 commit 5f32000
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 28 deletions.
41 changes: 14 additions & 27 deletions src/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,8 @@ tt_pjrt_status BufferInstance::CopyToHost(void *dst, size_t dst_size,
};

DLOG_F(INFO, "Copy to host id: %d", unique_id());
memcpy(dst, tensor().data.get(), dst_size);
tt::runtime::memcpy(dst, tensor());

auto copy_done_event = new EventInstance();
copy_done_event->OnReady(copy_done_callback, nullptr);

Expand Down Expand Up @@ -1006,45 +1007,29 @@ tt_pjrt_status
LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
DLOG_F(LOG_DEBUG, "LoadedExecutableInstance::Execute");

std::vector<tt::runtime::Tensor> rt_inputs;
std::vector<tt::runtime::Tensor> rt_outputs;

auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();
int dev_0 = chip_ids[0];
auto device = tt::runtime::openDevice({dev_0});

rt_inputs.reserve(args->num_args);
assert(args->num_devices == 1);
int dev_index = 0;
tt::runtime::Binary binary(image_->binary);

std::vector<tt::runtime::Tensor> rt_inputs;
rt_inputs.reserve(args->num_args);

for (size_t i = 0; i < args->num_args; ++i) {
auto *buffer = BufferInstance::Unwrap(args->argument_lists[dev_index][i]);
rt_inputs.emplace_back(buffer->tensor());
DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id());
}
auto output_specs = binary.getProgramOutputs(0);
for (auto output_spec : output_specs) {
int volume = 1;
for (auto dim : output_spec.shape) {
volume *= dim;
}
int size = volume * output_spec.itemsize;
// TODO What to do with stride
void *data = malloc(size);

#ifdef DEBUG
memset(data, 0, size);
#endif
std::shared_ptr<void> data_ptr(const_cast<void *>(data), [](void *) {});
tt::runtime::Tensor tensor = tt::runtime::createTensor(
data_ptr, output_spec.shape, output_spec.stride, output_spec.itemsize,
output_spec.dataType);
rt_outputs.emplace_back(tensor);
}

tt::runtime::Event event =
tt::runtime::submit(device, binary, 0, rt_inputs, rt_outputs);
(void)event;
std::vector<tt::runtime::Tensor> rt_outputs =
tt::runtime::submit(device, binary, 0, rt_inputs);
std::vector<tt::runtime::TensorDesc> output_specs =
binary.getProgramOutputs(0);

assert(rt_outputs.size() == output_specs.size());

for (size_t i = 0; i < output_specs.size(); ++i) {
auto result_buffer = std::make_unique<BufferInstance>(
Expand All @@ -1055,9 +1040,11 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
DLOG_F(INFO, "Runtime output id: %d", result_buffer->unique_id());
args->output_lists[dev_index][i] = *(result_buffer.release());
}

if (args->device_complete_events) {
args->device_complete_events[dev_index] = *(new EventInstance());
}

tt::runtime::closeDevice(device);

return tt_pjrt_status::kSuccess;
Expand Down
2 changes: 1 addition & 1 deletion third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
#

set(TT_MLIR_VERSION "0640c7c8366b9b44fd6e210eb7ca76d1c2ec4121")
set(TT_MLIR_VERSION "223d24444c419dec906f87749f161f03b321fce7")
set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1")

if (TOOLCHAIN STREQUAL "ON")
Expand Down

0 comments on commit 5f32000

Please sign in to comment.