Added support for scalars
ajakovljevicTT committed Dec 2, 2024
1 parent d770c94 commit 694b9d0
Showing 5 changed files with 53 additions and 9 deletions.
src/common/api_impl.cc: 8 changes (6 additions, 2 deletions)
@@ -1047,9 +1047,13 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
   (void)event;
 
   for (size_t i = 0; i < output_specs.size(); ++i) {
+    bool is_scalar = client_.get_module_builder()->is_output_scalar(i);
+    // PJRT expects an empty shape for scalars.
+    std::vector<std::uint32_t> output_shape =
+        is_scalar ? std::vector<std::uint32_t>() : output_specs[i].shape;
     auto result_buffer = std::make_unique<BufferInstance>(
-        *this->addressable_devices_[dev_index], rt_outputs[i],
-        output_specs[i].shape, output_specs[i].stride);
+        *this->addressable_devices_[dev_index], rt_outputs[i], output_shape,
+        output_specs[i].stride);
     result_buffer->setType(
         convertElementTypeToBufferType(output_specs[i].dataType));
     DLOG_F(INFO, "Runtime output id: %d", result_buffer->unique_id());
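
The convention the hunk above relies on: PJRT describes a scalar output as a rank-0 buffer, i.e. an empty dimension list, rather than a one-element shape such as {1}. A minimal sketch of that selection, using a hypothetical OutputSpec stand-in rather than the plugin's actual output-spec type:

#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-in for the runtime's per-output description.
struct OutputSpec {
  std::vector<std::uint32_t> shape; // e.g. {2, 2} for a 2x2 tensor
};

// Scalars are reported to PJRT as rank-0 buffers (an empty dims list, not {1});
// non-scalar outputs keep the shape reported by the runtime.
std::vector<std::uint32_t> dims_for_pjrt(const OutputSpec &spec,
                                         bool is_scalar) {
  return is_scalar ? std::vector<std::uint32_t>() : spec.shape;
}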
src/common/api_impl.h: 5 changes (5 additions, 0 deletions)
@@ -364,6 +364,11 @@ class ClientInstance {
   // Advances the timeline, returning (current, next) time point values.
   std::tuple<uint64_t, uint64_t> AdvanceTimeline();
 
+  // Returns the module builder used for this ClientInstance.
+  const ModuleBuilder *get_module_builder() const {
+    return module_builder_.get();
+  }
+
 protected:
   std::string cached_platform_name_;
   std::string cached_platform_version_;
src/common/module_builder.cc: 33 changes (33 additions, 0 deletions)
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/Parser/Parser.h"
@@ -65,6 +66,36 @@ ModuleBuilder::ModuleBuilder()
   m_context->appendDialectRegistry(registry);
 }
 
+static bool isScalarType(mlir::Type type) {
+  if (mlir::isa<mlir::FloatType>(type) || mlir::isa<mlir::IntegerType>(type)) {
+    return true;
+  }
+  if (auto tensorType = mlir::dyn_cast<mlir::RankedTensorType>(type)) {
+    return tensorType.getRank() == 0;
+  }
+  return false;
+}
+
+void ModuleBuilder::collectOutputTypes(mlir::ModuleOp &&module) {
+  m_is_output_scalar.clear();
+  for (auto &op : module.getOps()) {
+    if (auto funcOp = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
+      // We care only about return ops of public functions, as those are the
+      // ones that will produce results in the flatbuffer.
+      if (funcOp.isPublic()) {
+        funcOp.walk([&](mlir::Operation *op) {
+          if (mlir::func::ReturnOp return_op =
+                  mlir::dyn_cast<mlir::func::ReturnOp>(op)) {
+            for (auto operand : op->getOperands()) {
+              m_is_output_scalar.push_back(isScalarType(operand.getType()));
+            }
+          }
+        });
+      }
+    }
+  }
+}
+
 tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code,
                                           const std::string_view &format) {
   DLOG_F(LOG_DEBUG, "ModuleBuilder::buildModule");
@@ -134,6 +165,8 @@ void ModuleBuilder::convertFromVHLOToSHLO(
     return;
   }
 
+  collectOutputTypes(mlir_module.get());
+
   DLOG_F(LOG_DEBUG, "SHLO Module:");
   print_module(mlir_module);
 }
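
To illustrate what the new detection classifies as a scalar, here is a minimal, self-contained sketch (not part of the commit); the module text, the main() harness, and the mlir::parseSourceString usage are assumptions for illustration only. It applies the same rank-0 / integer / float rule to the results of a parsed func-dialect module:

#include <vector>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"

// Same classification as in module_builder.cc: plain integer/float types and
// rank-0 ranked tensors count as scalars.
static bool isScalarType(mlir::Type type) {
  if (mlir::isa<mlir::FloatType>(type) || mlir::isa<mlir::IntegerType>(type))
    return true;
  if (auto tensorType = mlir::dyn_cast<mlir::RankedTensorType>(type))
    return tensorType.getRank() == 0;
  return false;
}

int main() {
  const char *kModule = R"mlir(
    func.func @main(%arg0: tensor<f32>, %arg1: tensor<2x2xf32>)
        -> (tensor<f32>, tensor<2x2xf32>) {
      return %arg0, %arg1 : tensor<f32>, tensor<2x2xf32>
    }
  )mlir";

  mlir::MLIRContext context;
  context.loadDialect<mlir::func::FuncDialect>();
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(kModule,
                                              mlir::ParserConfig(&context));
  if (!module)
    return 1;

  // Walk return ops and record whether each returned value is a scalar.
  std::vector<bool> is_output_scalar;
  module->walk([&](mlir::func::ReturnOp return_op) {
    for (mlir::Value operand : return_op.getOperands())
      is_output_scalar.push_back(isScalarType(operand.getType()));
  });
  // Expected contents: {true, false} -- tensor<f32> has rank 0,
  // tensor<2x2xf32> does not.
  return 0;
}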
src/common/module_builder.h: 9 changes (9 additions, 0 deletions)
@@ -33,6 +33,8 @@ class ModuleBuilder {
 
   size_t getNumOutputs() const { return m_num_outputs; };
 
+  bool is_output_scalar(int index) const { return m_is_output_scalar[index]; }
+
 private:
   // Creates VHLO module from the input program code.
   mlir::OwningOpRef<mlir::ModuleOp>
@@ -51,6 +53,10 @@ class ModuleBuilder {
   void
   createFlatbufferBinary(const mlir::OwningOpRef<mlir::ModuleOp> &mlir_module);
 
+  // Fills the m_is_output_scalar vector with information on whether each
+  // output type is a scalar or not.
+  void collectOutputTypes(mlir::ModuleOp &&module);
+
   // Prints module to console for debug purposes.
   static void print_module(mlir::OwningOpRef<mlir::ModuleOp> &mlir_module);
 
@@ -68,6 +74,9 @@ class ModuleBuilder {
 
   // Holds status of the last builder action.
   tt_pjrt_status m_status;
+
+  // For every output, holds whether its type is a scalar.
+  std::vector<bool> m_is_output_scalar;
 };
 
 } // namespace tt::pjrt
tests/infrastructure.py: 7 changes (0 additions, 7 deletions)
@@ -40,13 +40,6 @@ def compare_tensor_to_golden(
 ):
     ret = True
 
-    # TODO (issue #81): Remove these reshapes once the PJRT can handle scalars.
-    if tensor.ndim == 0:
-        tensor = tensor.reshape((1,))
-    if golden.ndim == 0:
-        with run_on_cpu():
-            golden = golden.reshape((1,))
-
     if tensor.device != golden.device:
         tensor = jax.device_put(tensor, golden.device)
 
