Skip to content

Commit

Permalink
[IFRT] Add option to compile IFRT IR atom programs using Sdy
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705948446
  • Loading branch information
ICGog authored and Google-ML-Automation committed Dec 13, 2024
1 parent 781a990 commit b2949d5
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 0 deletions.
5 changes: 5 additions & 0 deletions xla/python/ifrt/ir/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ inline constexpr llvm::StringLiteral kIfrtMemoryKindAttrName =
inline constexpr llvm::StringLiteral kIfrtEntryFunctionAttrName =
"ifrt.entry_function";

// Name of the UnitAttr set on a CallOp to indicate that the atom program it
// calls was partitioned by the Sdy (Shardy) partitioner.
// NOTE(review): the atom program compilation pass appears to be the consumer
// of this marker: it tests for the attribute on each CallOp and, when present,
// copies the meshes roundtrip frontend attribute from the program module onto
// the callee module. Confirm there are no other readers before renaming.
inline constexpr llvm::StringLiteral kIsSdyPartitioned =
"ifrt.is_sdy_partitioned";

inline constexpr llvm::StringLiteral kCalleeMainFuncName = "main";

// Name of StringAttr used to store the HloSharding.
Expand Down
31 changes: 31 additions & 0 deletions xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,34 @@ module @call_hlo {
}
}
}

// -----

// Test case: an atom program whose call is marked `ifrt.is_sdy_partitioned`.
// The enclosing module carries the `xla.sdy.meshes` frontend attribute, and
// the call is expected to be rewritten into an `ifrt.CallLoadedExecutable`
// that refers to an `ifrt.LoadedExecutable` on devices [0, 1].
!array = !ifrt.array<tensor<2x2xi32>,
#ifrt.sharding_param<2x1 to [0] on 2>, [0,1]>
// CHECK-LABEL: @call_hlo_sdy_lowered
module @call_hlo_sdy_lowered attributes {
mhlo.frontend_attributes = {
xla.sdy.meshes ="{mesh = #sdy.mesh<[\\\22x\\\22=2]>}"}} {
func.func @main(%arg0: !array) -> !array attributes {ifrt.function} {
// CHECK: ifrt.CallLoadedExecutable @fake_component__fake_method_1(%arg0)
%0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1]
{ifrt.module_type = "xla", ifrt.is_sdy_partitioned} : (!array) -> !array
return %0 : !array
}

// NOTE(review): the commented-out line below looks like the callee module
// header expected after the pass attaches the meshes attribute. It is a plain
// comment and is not matched by FileCheck; confirm whether it should become a
// directive or be removed.
// module @add_one attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22x\\\22=2]>}"}, sym_visibility = "private"}
// CHECK: ifrt.LoadedExecutable @fake_component__fake_method
// CHECK-SAME: on devices [0, 1]
// CHECK: (!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>)
// CHECK-SAME: -> !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
module @add_one attributes {sym_visibility = "private"} {
func.func private @main(
%arg0: tensor<2x2xi32> {mhlo.sharding = "{devices=[2,1]<=[2]}"})
-> (tensor<2x2xi32> {mhlo.sharding = "{devices=[2,1]<=[2]}"}) {
%0 = mhlo.constant dense<1> : tensor<2x2xi32>
%1 = mhlo.add %arg0, %0 : tensor<2x2xi32>
return %1 : tensor<2x2xi32>
}
}
}
3 changes: 3 additions & 0 deletions xla/python/ifrt/ir/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ cc_library(
"//xla/service:compilation_environments",
"//xla/service:computation_placer_hdr",
"//xla/service:hlo_proto_cc",
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
Expand All @@ -106,6 +108,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:register",
"@stablehlo//:stablehlo_ops",
"@stablehlo//:stablehlo_serialization",
Expand Down
28 changes: 28 additions & 0 deletions xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
Expand All @@ -42,6 +43,7 @@ limitations under the License.
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/python/ifrt/compiler.h"
Expand All @@ -52,6 +54,8 @@ limitations under the License.
#include "xla/python/ifrt/ir/transforms/passes.h"
#include "xla/python/ifrt/ir/transforms/utils.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/service/spmd/shardy/utils.h"

namespace xla {
namespace ifrt {
Expand Down Expand Up @@ -83,6 +87,7 @@ class IfrtCompileAtomProgramPass
// Declares the dialects this pass depends on by inserting the MHLO, StableHLO
// and Sdy dialects into `registry`, in that order.
void getDependentDialects(::mlir::DialectRegistry& registry) const override {
  registry.insert<mlir::mhlo::MhloDialect, mlir::stablehlo::StablehloDialect,
                  mlir::sdy::SdyDialect>();
}

void runOnOperation() override;
Expand All @@ -108,6 +113,14 @@ void IfrtCompileAtomProgramPass::runOnOperation() {
// Map from the hash of the CallOp to the compile future.
llvm::DenseMap<CallOp, CompileFuture, IfrtCallOpInfo> call_to_compile_futures;
mlir::ModuleOp module_op = getOperation();

mlir::Attribute meshes_round_trip_attr;
// TODO(icgog): This attribute will be deleted in the IFRT -> VIFRT
// legalization. Fix this so that Sdy can be used with VIFRT.
if (auto front_end_attr = xla::sdy::getFrontendAttrs(module_op)) {
meshes_round_trip_attr = front_end_attr.get(xla::sdy::kMeshesRoundTripAttr);
}

// Walk and dispatch the compilations in parallel.
auto compile_result =
module_op.walk([&](CallOp call_op) -> mlir::WalkResult {
Expand All @@ -125,6 +138,21 @@ void IfrtCompileAtomProgramPass::runOnOperation() {
<< callee.getSymName() << ". Actual callee parent: "
<< callee->getParentOp()->getName();
}

if (call_op->hasAttr(kIsSdyPartitioned)) {
// Add the meshes roundtrip attribute to the callee module if the
// atom program was partitioned with sdy.
if (!meshes_round_trip_attr) {
return call_op.emitOpError()
<< "requires meshes roundtrip attribute to be set on the "
"program module if the atom program was partitioned "
"with sdy.";
}
xla::sdy::setFrontendAttribute(
callee_module, xla::sdy::kMeshesRoundTripAttr,
meshes_round_trip_attr, /*escapeAttr=*/false);
}

absl::StatusOr<CompileFuture> compile_future =
atom_program_compiler_.CompileModule(call_op, callee_module);
if (!compile_future.ok()) {
Expand Down

0 comments on commit b2949d5

Please sign in to comment.