From fd091d499a90bd73e9eedcd648bfe35229e97ce3 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sat, 30 Mar 2024 02:39:15 -0700 Subject: [PATCH] [xla:gpu] Unify static and dynamic slice cases for AddressComputationFusionRewriter PiperOrigin-RevId: 620453556 --- .../address_computation_fusion_rewriter.cc | 49 ++++++++++--------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc index afb429b1942ab3..03317a8f09e166 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc @@ -138,8 +138,7 @@ bool IsAlignedSlice(const Shape& src_shape, const Shape& dst_shape, return true; } -UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr, - bool dynamic) { +UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { UseDefDataflowPaths sliced_operand_paths; auto fusion = HloFusionAdaptor::ForComputation(instr->parent()); @@ -164,34 +163,31 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr, auto maybe_slice_adaptor = HloFindIf({HloInstructionAdaptor(*operand)}, *fusion, [&](auto node) { const HloInstruction* cur = &node.instruction(); + // If the node is a match that has been processed, stop the traversal. if (processed_instrs.contains(cur)) return true; + maybe_sliced_operand_path.push_back(const_cast(cur)); - if (dynamic) { - if (const auto slice_instr = - DynCast(cur)) { - if (IsAlignedSlice(slice_instr->operand(0)->shape(), - slice_instr->shape(), nullptr)) { - slice_found = true; - return slice_found; - } - } - } else { - if (const auto slice_instr = DynCast(cur)) { - if (IsAlignedSlice(slice_instr->operand(0)->shape(), - slice_instr->shape(), slice_instr)) { - slice_found = true; - return slice_found; - } + + if (IsOpcodeAnyOf( + node)) { + if (IsAlignedSlice(cur->operand(0)->shape(), cur->shape(), + DynCast(cur))) { + slice_found = true; + return slice_found; } } + // TODO(vuson): lift the first restriction by considering fusing other // uses of the operand to reuse the address computation. Only worth it // if other uses are also custom calls though. return cur->user_count() > 1 || !IsNoOp(cur); }); + if (maybe_slice_adaptor == std::nullopt) continue; + const auto& maybe_slice_instr = maybe_slice_adaptor->instruction(); + if (slice_found || processed_instrs.contains(&maybe_slice_instr)) { // Even in the case of stopping at a match that has been processed, we // still need to add instructions encountered in the sliced operand path @@ -415,11 +411,11 @@ absl::StatusOr AddressComputationFusionRewriter::Run( for (HloInstruction* instr : computation->instructions()) { if (IsLegacyCublasMatmul(*instr) || (!dynamic && IsCustomCall(instr, platform_name_))) { - auto sliced_operand_paths = GetSlicedOperandPaths(instr, dynamic); + UseDefDataflowPaths sliced_operand_paths = + GetSlicedOperandPaths(instr); bool has_sliced_operand_paths = sliced_operand_paths.size() > 1; - DefUseDataflowPaths sliced_user_paths{}; - if (dynamic) sliced_user_paths = GetSlicedUserPaths(instr); + DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr); bool has_sliced_user_paths = absl::c_any_of(sliced_user_paths, [&](auto& sliced_user_path) { return !sliced_user_path.empty(); @@ -464,9 +460,14 @@ absl::StatusOr AddressComputationFusionRewriter::Run( DataflowPathsView(sliced_user_paths_view), captures)); - TF_ASSIGN_OR_RETURN(HloInstruction * fusion, - CreateFusionInstruction(module, hero, captures, - fusion_body, dynamic)); + bool has_dynamic_slices = + absl::c_any_of(matched_instrs, [&](auto* instr) { + return DynCast(instr) != nullptr; + }); + TF_ASSIGN_OR_RETURN( + HloInstruction * fusion, + CreateFusionInstruction(module, hero, captures, fusion_body, + has_dynamic_slices)); HloComputation* parent = hero->parent(); if (fusion->shape().IsTuple()) {