From c637ed1ca9e65a5f02d8f287cfc025d3e9b6cdb0 Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Fri, 13 Dec 2024 10:21:42 -0800 Subject: [PATCH] [IFRT] Add option to compile IFRT IR atom programs using Sdy PiperOrigin-RevId: 705924258 --- xla/python/ifrt/ir/constants.h | 5 +++ .../ir/tests/ifrt_compile_atom_program.mlir | 31 +++++++++++++++++++ xla/python/ifrt/ir/transforms/BUILD | 3 ++ .../ifrt_compile_atom_program_pass.cc | 28 +++++++++++++++++ 4 files changed, 67 insertions(+) diff --git a/xla/python/ifrt/ir/constants.h b/xla/python/ifrt/ir/constants.h index 52b22e7b9c5dd..512b22259fdc0 100644 --- a/xla/python/ifrt/ir/constants.h +++ b/xla/python/ifrt/ir/constants.h @@ -57,6 +57,11 @@ inline constexpr llvm::StringLiteral kIfrtMemoryKindAttrName = inline constexpr llvm::StringLiteral kIfrtEntryFunctionAttrName = "ifrt.entry_function"; +// Name of UnitAttr on CallOp used to indicate that an atom program was +// partitioned by the Sdy partitioner. +inline constexpr llvm::StringLiteral kIsSdyPartitioned = + "ifrt.is_sdy_partitioned"; + inline constexpr llvm::StringLiteral kCalleeMainFuncName = "main"; // Name of StringAttr used to store the HloSharding. 
diff --git a/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir b/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir index b99e0f9a43b79..22257730e01d5 100644 --- a/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir @@ -25,3 +25,34 @@ module @call_hlo { } } } + +// ----- + +!array = !ifrt.array<tensor<2x2xi32>, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @call_hlo_sdy_lowered +module @call_hlo_sdy_lowered attributes { + mhlo.frontend_attributes = { + xla.sdy.meshes ="{mesh = #sdy.mesh<[\\\22x\\\22=2]>}"}} { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: ifrt.CallLoadedExecutable @fake_component__fake_method_1(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] + {ifrt.module_type = "xla", ifrt.is_sdy_partitioned} : (!array) -> !array + return %0 : !array + } + + // module @add_one attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22x\\\22=2]>}"}, sym_visibility = "private"} + // CHECK: ifrt.LoadedExecutable @fake_component__fake_method + // CHECK-SAME: on devices [0, 1] + // CHECK: (!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>) + // CHECK-SAME: -> !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + module @add_one attributes {sym_visibility = "private"} { + func.func private @main( + %arg0: tensor<2x2xi32> {mhlo.sharding = "{devices=[2,1]<=[2]}"}) + -> (tensor<2x2xi32> {mhlo.sharding = "{devices=[2,1]<=[2]}"}) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} diff --git a/xla/python/ifrt/ir/transforms/BUILD b/xla/python/ifrt/ir/transforms/BUILD index 2ba4b8501f645..c7bf2bd0c2702 100644 --- a/xla/python/ifrt/ir/transforms/BUILD +++ b/xla/python/ifrt/ir/transforms/BUILD @@ -85,6 +85,8 @@ cc_library( "//xla/service:compilation_environments", "//xla/service:computation_placer_hdr", 
"//xla/service:hlo_proto_cc", + "//xla/service/spmd/shardy:constants", + "//xla/service/spmd/shardy:utils", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -106,6 +108,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@shardy//shardy/dialect/sdy/ir:dialect", "@stablehlo//:register", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_serialization", diff --git a/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc b/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc index 04f005ff73cb4..216fb974c024b 100644 --- a/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc +++ b/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc @@ -26,6 +26,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" @@ -42,6 +43,7 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/TypeID.h" +#include "shardy/dialect/sdy/ir/dialect.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/python/ifrt/compiler.h" @@ -52,6 +54,8 @@ limitations under the License. 
#include "xla/python/ifrt/ir/transforms/passes.h" #include "xla/python/ifrt/ir/transforms/utils.h" #include "xla/service/hlo.pb.h" +#include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/utils.h" namespace xla { namespace ifrt { @@ -83,6 +87,7 @@ class IfrtCompileAtomProgramPass void getDependentDialects(::mlir::DialectRegistry& registry) const override { registry.insert(); registry.insert(); + registry.insert(); } void runOnOperation() override; @@ -108,6 +113,14 @@ void IfrtCompileAtomProgramPass::runOnOperation() { // Map from the hash of the CallOp to the compile future. llvm::DenseMap call_to_compile_futures; mlir::ModuleOp module_op = getOperation(); + + mlir::Attribute meshes_round_trip_attr; + // TODO: icgog - This attribute will be deleted in the IFRT -> VIFRT + // legalization. Fix in order to be able to use Sdy with VIFRT. + if (auto front_end_attr = xla::sdy::getFrontendAttrs(module_op)) { + meshes_round_trip_attr = front_end_attr.get(xla::sdy::kMeshesRoundTripAttr); + } + // Walk and dispatch the compilations in parallel. auto compile_result = module_op.walk([&](CallOp call_op) -> mlir::WalkResult { @@ -125,6 +138,21 @@ void IfrtCompileAtomProgramPass::runOnOperation() { << callee.getSymName() << ". Actual callee parent: " << callee->getParentOp()->getName(); } + + if (call_op->hasAttr(kIsSdyPartitioned)) { + // Add the meshes roundtrip attribute to the callee module if the + // atom program was partitioned with sdy. + if (!meshes_round_trip_attr) { + return call_op.emitOpError() + << "requires meshes roundtrip attribute to be set on the " + "program module if the atom program was partitioned " + "with sdy."; + } + xla::sdy::setFrontendAttribute( + callee_module, xla::sdy::kMeshesRoundTripAttr, + meshes_round_trip_attr, /*escapeAttr=*/false); + } + absl::StatusOr compile_future = atom_program_compiler_.CompileModule(call_op, callee_module); if (!compile_future.ok()) {