Skip to content

Commit

Permalink
[xla:gpu] Unify static and dynamic slice cases for AddressComputation…
Browse files Browse the repository at this point in the history
…FusionRewriter

PiperOrigin-RevId: 620453556
  • Loading branch information
tyb0807 authored and tensorflower-gardener committed Mar 30, 2024
1 parent c65e205 commit fd091d4
Showing 1 changed file with 25 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ bool IsAlignedSlice(const Shape& src_shape, const Shape& dst_shape,
return true;
}

UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr,
bool dynamic) {
UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) {
UseDefDataflowPaths sliced_operand_paths;

auto fusion = HloFusionAdaptor::ForComputation(instr->parent());
Expand All @@ -164,34 +163,31 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr,
auto maybe_slice_adaptor =
HloFindIf({HloInstructionAdaptor(*operand)}, *fusion, [&](auto node) {
const HloInstruction* cur = &node.instruction();

// If the node is a match that has been processed, stop the traversal.
if (processed_instrs.contains(cur)) return true;

maybe_sliced_operand_path.push_back(const_cast<HloInstruction*>(cur));
if (dynamic) {
if (const auto slice_instr =
DynCast<HloDynamicSliceInstruction>(cur)) {
if (IsAlignedSlice(slice_instr->operand(0)->shape(),
slice_instr->shape(), nullptr)) {
slice_found = true;
return slice_found;
}
}
} else {
if (const auto slice_instr = DynCast<HloSliceInstruction>(cur)) {
if (IsAlignedSlice(slice_instr->operand(0)->shape(),
slice_instr->shape(), slice_instr)) {
slice_found = true;
return slice_found;
}

if (IsOpcodeAnyOf<HloOpcode::kDynamicSlice, HloOpcode::kSlice>(
node)) {
if (IsAlignedSlice(cur->operand(0)->shape(), cur->shape(),
DynCast<HloSliceInstruction>(cur))) {
slice_found = true;
return slice_found;
}
}

// TODO(vuson): lift the first restriction by considering fusing other
// uses of the operand to reuse the address computation. Only worth it
// if other uses are also custom calls though.
return cur->user_count() > 1 || !IsNoOp(cur);
});

if (maybe_slice_adaptor == std::nullopt) continue;

const auto& maybe_slice_instr = maybe_slice_adaptor->instruction();

if (slice_found || processed_instrs.contains(&maybe_slice_instr)) {
// Even in the case of stopping at a match that has been processed, we
// still need to add instructions encountered in the sliced operand path
Expand Down Expand Up @@ -415,11 +411,11 @@ absl::StatusOr<bool> AddressComputationFusionRewriter::Run(
for (HloInstruction* instr : computation->instructions()) {
if (IsLegacyCublasMatmul(*instr) ||
(!dynamic && IsCustomCall(instr, platform_name_))) {
auto sliced_operand_paths = GetSlicedOperandPaths(instr, dynamic);
UseDefDataflowPaths sliced_operand_paths =
GetSlicedOperandPaths(instr);
bool has_sliced_operand_paths = sliced_operand_paths.size() > 1;
DefUseDataflowPaths sliced_user_paths{};
if (dynamic) sliced_user_paths = GetSlicedUserPaths(instr);

DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr);
bool has_sliced_user_paths =
absl::c_any_of(sliced_user_paths, [&](auto& sliced_user_path) {
return !sliced_user_path.empty();
Expand Down Expand Up @@ -464,9 +460,14 @@ absl::StatusOr<bool> AddressComputationFusionRewriter::Run(
DataflowPathsView(sliced_user_paths_view),
captures));

TF_ASSIGN_OR_RETURN(HloInstruction * fusion,
CreateFusionInstruction(module, hero, captures,
fusion_body, dynamic));
bool has_dynamic_slices =
absl::c_any_of(matched_instrs, [&](auto* instr) {
return DynCast<HloDynamicIndexInstruction>(instr) != nullptr;
});
TF_ASSIGN_OR_RETURN(
HloInstruction * fusion,
CreateFusionInstruction(module, hero, captures, fusion_body,
has_dynamic_slices));

HloComputation* parent = hero->parent();
if (fusion->shape().IsTuple()) {
Expand Down

0 comments on commit fd091d4

Please sign in to comment.