Skip to content

Commit

Permalink
Addressed more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ajakovljevicTT committed Dec 11, 2024
1 parent f3f317f commit 65ece7a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/common/api_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class ClientInstance {
std::tuple<uint64_t, uint64_t> AdvanceTimeline();

// Checks if the output on the i-th index is a scalar.
bool isOutputScalar(const int index) const {
bool isOutputScalar(const size_t index) const {
return module_builder_->isOutputScalar(index);
}

Expand Down
28 changes: 13 additions & 15 deletions src/common/module_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ ModuleBuilder::ModuleBuilder()
m_context->appendDialectRegistry(registry);
}

bool ModuleBuilder::isOutputScalar(int index) const {
assert(index < m_is_output_scalar.size() && "Index of output out of range.");
bool ModuleBuilder::isOutputScalar(size_t index) const {
assert(index < m_is_output_scalar.size() && "Output index out of range");
return m_is_output_scalar[index];
}

Expand Down Expand Up @@ -151,18 +151,14 @@ void ModuleBuilder::collectOutputTypes(

// We care only about return ops of public functions, as that are the ones
// that will produce results in the flatbuffer.
if (funcOp && funcOp.isPublic()) {
for (const mlir::Type &returnType :
funcOp.getFunctionType().getResults()) {
m_is_output_scalar.push_back(isScalarType(returnType));
}
if (!funcOp) {
return;
}

if (moduleOp) {
if (moduleOp == module.get()) {
return;
}
collectOutputTypes(moduleOp);
if (!funcOp.isPublic()) {
return;
}
for (const mlir::Type &returnType : funcOp.getFunctionType().getResults()) {
m_is_output_scalar.push_back(isScalarType(returnType));
}
});
}
Expand Down Expand Up @@ -234,9 +230,11 @@ void ModuleBuilder::createFlatbufferBinary(
m_num_outputs = runtime_binary_handle.getProgramOutputs(0).size();

if (m_num_outputs != m_is_output_scalar.size()) {
DLOG_F(ERROR, "The number of return types does not match ");
DLOG_F(ERROR,
"Created flatbuffer binary contains different number of outputs %d "
"than expected %d",
m_num_outputs, m_is_output_scalar.size());
m_status = tt_pjrt_status::kInternal;
return;
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/common/module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ModuleBuilder {

size_t getNumOutputs() const { return m_num_outputs; };

bool isOutputScalar(int index) const;
bool isOutputScalar(size_t index) const;

private:
// Creates VHLO module from the input program code.
Expand Down

0 comments on commit 65ece7a

Please sign in to comment.