Commit 92b03bd
[xla:gpu][NFC] Make lambdas static functions for better reusability
PiperOrigin-RevId: 620511999
tyb0807 authored and tensorflower-gardener committed Mar 30, 2024
1 parent fd091d4 commit 92b03bd
Showing 2 changed files with 179 additions and 144 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/fusions/BUILD
@@ -147,6 +147,7 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:IR",
322 changes: 178 additions & 144 deletions third_party/xla/xla/service/gpu/fusions/custom.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/AsmParser/AsmParser.h" // from @llvm-project
@@ -137,6 +138,130 @@ absl::StatusOr<BufferAllocation::Slice> GetSliceWithUpdatedOffsetAndSize(
return BufferAllocation::Slice(orig_slice.allocation(), offset, size);
}

+ absl::StatusOr<BufferAllocation::Slice> GetOperandSlice(
+ const BufferAssignment& buffer_assignment, const HloFusionAdaptor& adaptor,
+ const HloInstruction& fusion_instr, const HloInstruction& start_instr,
+ std::vector<HloInstruction*>& slice_instrs, const ShapeIndex& shape_idx,
+ unsigned arg_idx) {
+ auto slice_adaptor =
+ HloFindIf({HloInstructionAdaptor(start_instr)}, adaptor, [](auto node) {
+ return IsOpcodeAnyOf<HloOpcode::kDynamicSlice, HloOpcode::kSlice>(node);
+ });
+ if (slice_adaptor.has_value()) {
+ auto* slice_instr =
+ const_cast<HloInstruction*>(&slice_adaptor->instruction());
+
+ if (!IsContiguousSlice(slice_instr->operand(0)->shape(),
+ slice_instr->shape())) {
+ return absl::InternalError(
+ "DynamicAddressComputationFusion only handles contiguous slices "
+ "currently");
+ }
+
+ slice_instrs[arg_idx] = slice_instr;
+
+ const auto* param = Cast<HloParameterInstruction>(slice_instr->operand(0));
+ TF_ASSIGN_OR_RETURN(
+ BufferAllocation::Slice orig_slice,
+ GetAllocationSlice(buffer_assignment,
+ fusion_instr.operand(param->parameter_number()),
+ shape_idx));
+
+ if (auto* static_slice = DynCast<HloSliceInstruction>(slice_instr)) {
+ // Update static slices.
+ const Shape& src_shape = static_slice->operand(0)->shape();
+ const Shape& dst_shape = static_slice->shape();
+ int64_t size = ShapeUtil::ByteSizeOf(dst_shape);
+
+ // Given this slice
+ // f16[1,4,8]{2,1,0} slice(f16[2,8,8]{2,1,0}),
+ // slice={[1:2], [4:8], [0:8]}
+ //
+ // The offset of the slice should be:
+ // slice_starts(0) * 8 * 8 * sizeof(f16) +
+ // slice_starts(1) * 8 * sizeof(f16)
+ int64_t offset = orig_slice.offset();
+ for (auto [start, stride] :
+ llvm::zip(static_slice->slice_starts(),
+ *ShapeUtil::ByteStrides(src_shape))) {
+ offset += start * stride;
+ }
+
+ return BufferAllocation::Slice(orig_slice.allocation(), offset, size);
+ }
+
+ return orig_slice;
+ }
+
+ const auto* param = DynCast<HloParameterInstruction>(&start_instr);
+ return GetAllocationSlice(buffer_assignment,
+ fusion_instr.operand(param->parameter_number()),
+ shape_idx);
+ }
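
The static-slice branch of GetOperandSlice converts slice_starts into a byte offset via the source shape's byte strides. Below is a self-contained sketch of that arithmetic for the f16[1,4,8]{2,1,0} slice(f16[2,8,8]{2,1,0}) example in the comment above; ByteStrides here is a hand-rolled stand-in for ShapeUtil::ByteStrides, assuming the default descending (row-major) layout:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Byte strides of a row-major (minor-to-major {2,1,0}) shape.
    static std::vector<int64_t> ByteStrides(const std::vector<int64_t>& dims,
                                            int64_t elem_bytes) {
      std::vector<int64_t> strides(dims.size());
      int64_t stride = elem_bytes;
      for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
        strides[i] = stride;
        stride *= dims[i];
      }
      return strides;
    }

    int main() {
      // f16[1,4,8]{2,1,0} slice(f16[2,8,8]{2,1,0}), slice={[1:2], [4:8], [0:8]}
      const std::vector<int64_t> src_dims = {2, 8, 8};
      const std::vector<int64_t> slice_starts = {1, 4, 0};
      const auto strides = ByteStrides(src_dims, /*elem_bytes=*/2);  // sizeof(f16)
      int64_t offset = 0;  // added to orig_slice.offset() in GetOperandSlice
      for (size_t i = 0; i < slice_starts.size(); ++i) {
        offset += slice_starts[i] * strides[i];
      }
      // 1 * (8*8*2) + 4 * (8*2) + 0 * 2 = 128 + 64 = 192 bytes.
      std::printf("relative byte offset = %lld\n", static_cast<long long>(offset));
      return 0;
    }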

+ absl::Status CollectSliceInfo(
+ const BufferAssignment& buffer_assignment,
+ const HloInstruction& fusion_instr,
+ absl::Span<HloInstruction*> slice_instrs,
+ std::vector<std::optional<std::vector<BufferAllocation::Slice>>>&
+ offset_buffer_indices,
+ std::vector<std::optional<Shape>>& orig_shapes,
+ std::vector<std::optional<Shape>>& sliced_shapes,
+ std::vector<std::optional<uint64_t>>& offset_byte_sizes, unsigned arg_idx) {
+ auto* slice_instr =
+ DynCastOrNull<HloDynamicIndexInstruction>(slice_instrs[arg_idx]);
+ if (slice_instr == nullptr) {
+ return absl::OkStatus();
+ }
+
+ std::vector<BufferAllocation::Slice> offset_slices;
+ for (auto idx_op : slice_instr->index_operands()) {
+ const auto* param = Cast<HloParameterInstruction>(idx_op);
+ TF_ASSIGN_OR_RETURN(
+ auto offset_slice,
+ GetAllocationSlice(buffer_assignment,
+ fusion_instr.operand(param->parameter_number()),
+ /*index=*/{}));
+ offset_slices.push_back(offset_slice);
+ }
+ offset_buffer_indices[arg_idx] = std::move(offset_slices);
+ orig_shapes[arg_idx] = slice_instr->operand(0)->shape();
+ sliced_shapes[arg_idx] = DynCast<HloDynamicSliceInstruction>(slice_instr)
+ ? slice_instr->shape()
+ : slice_instr->operand(1)->shape();
+ offset_byte_sizes[arg_idx] = ShapeUtil::ByteSizeOfPrimitiveType(
+ slice_instr->index_operands().front()->shape().element_type());
+
+ return absl::OkStatus();
+ }
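
CollectSliceInfo records four pieces of per-argument metadata, all indexed by arg_idx: the offset buffers, the unsliced operand shape, the sliced shape, and the byte width of the offset elements. The one subtle line is the sliced-shape ternary: for a dynamic-slice the sliced shape is the instruction's own result shape, while for a dynamic-update-slice it is the update operand's (operand 1) shape. A toy mirror of that ternary with stand-in types:

    #include <cassert>
    #include <string>

    enum class Opcode { kDynamicSlice, kDynamicUpdateSlice };

    struct Instr {
      Opcode opcode;
      std::string result_shape;  // stands in for shape()
      std::string update_shape;  // stands in for operand(1)->shape(); DUS only
    };

    // Mirrors: DynCast<HloDynamicSliceInstruction>(i) ? i->shape()
    //                                                 : i->operand(1)->shape();
    static std::string SlicedShape(const Instr& i) {
      return i.opcode == Opcode::kDynamicSlice ? i.result_shape : i.update_shape;
    }

    int main() {
      Instr ds{Opcode::kDynamicSlice, "f16[1,4,8]", ""};
      Instr dus{Opcode::kDynamicUpdateSlice, "f16[2,8,8]", "f16[1,4,8]"};
      assert(SlicedShape(ds) == "f16[1,4,8]");   // the slice result
      assert(SlicedShape(dus) == "f16[1,4,8]");  // the update, not the result
      return 0;
    }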

+ absl::StatusOr<BufferAllocation::Slice> GetResultSlice(
+ const BufferAssignment& buffer_assignment, const HloFusionAdaptor& adaptor,
+ const HloInstruction& fusion_instr, const HloInstruction& start_instr,
+ std::vector<HloInstruction*>& slice_instrs, const ShapeIndex& shape_idx,
+ unsigned arg_idx) {
+ auto slice_adaptor = HloFindIf(
+ {HloInstructionAdaptor(start_instr)}, adaptor,
+ [](auto node) { return node.opcode() == HloOpcode::kDynamicUpdateSlice; },
+ false);
+ if (slice_adaptor.has_value()) {
+ auto* slice_instr =
+ const_cast<HloInstruction*>(&slice_adaptor->instruction());
+ slice_instrs[arg_idx] = slice_instr;
+
+ if (!IsContiguousSlice(slice_instr->shape(),
+ Cast<HloDynamicUpdateSliceInstruction>(slice_instr)
+ ->update()
+ ->shape())) {
+ return absl::InternalError(
+ "DynamicAddressComputationFusion only handles contiguous slices "
+ "currently");
+ }
+ }
+
+ return GetAllocationSlice(buffer_assignment, &fusion_instr, shape_idx);
+ }
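
Both GetOperandSlice and GetResultSlice reject non-contiguous slices. The sketch below models one plausible reading of that contiguity requirement for the row-major layouts used in the comments here (an assumption about IsContiguousSlice's semantics, not a copy of its implementation): scanning dimensions minor-to-major, a contiguous slice takes whole dimensions first, then at most one partially sliced dimension, after which every more-major dimension must have extent 1.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Assumed contiguity rule for row-major shapes (illustrative, not XLA code).
    static bool IsContiguousSliceSketch(const std::vector<int64_t>& src_dims,
                                        const std::vector<int64_t>& dst_dims) {
      bool partial_seen = false;
      for (int i = static_cast<int>(src_dims.size()) - 1; i >= 0; --i) {
        if (partial_seen && dst_dims[i] != 1) return false;
        if (dst_dims[i] != src_dims[i]) partial_seen = true;
      }
      return true;
    }

    int main() {
      // f16[1,4,8] out of f16[2,8,8]: rows 4..7 of one plane -> one memory run.
      assert(IsContiguousSliceSketch({2, 8, 8}, {1, 4, 8}));
      // f16[2,4,8] out of f16[2,8,8]: same rows in both planes -> two runs.
      assert(!IsContiguousSliceSketch({2, 8, 8}, {2, 4, 8}));
      // f16[1,4,4] out of f16[2,8,8]: a 4x4 corner -> four separate runs.
      assert(!IsContiguousSliceSketch({2, 8, 8}, {1, 4, 4}));
      return 0;
    }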

absl::StatusOr<FusionEmissionResult> EmitGemm(
IrEmitterContext& ir_emitter_context, const HloFusionAdaptor& adaptor,
const HloFusionInstruction& fusion,
@@ -151,158 +276,67 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
std::vector<std::optional<uint64_t>> offset_byte_sizes(4, std::nullopt);

std::vector<HloInstruction*> slice_instrs(4, nullptr);
- auto get_original_operand_slice =
- [&](const HloInstruction* start, const ShapeIndex& index,
- unsigned param_idx) -> absl::StatusOr<BufferAllocation::Slice> {
- auto slice_adaptor =
- HloFindIf({HloInstructionAdaptor(*start)}, adaptor, [](auto node) {
- return IsOpcodeAnyOf<HloOpcode::kDynamicSlice, HloOpcode::kSlice>(
- node);
- });
- if (slice_adaptor.has_value()) {
- auto* slice_instr =
- const_cast<HloInstruction*>(&slice_adaptor->instruction());
-
- if (!IsContiguousSlice(slice_instr->operand(0)->shape(),
- slice_instr->shape())) {
- return absl::InternalError(
- "DynamicAddressComputationFusion only handles contiguous slices "
- "currently");
- }
-
- slice_instrs[param_idx] = slice_instr;
-
- const auto* param =
- Cast<HloParameterInstruction>(slice_instr->operand(0));
- TF_ASSIGN_OR_RETURN(
- BufferAllocation::Slice orig_slice,
- GetAllocationSlice(buffer_assignment,
- fusion.operand(param->parameter_number()), index));
-
- if (auto* static_slice = DynCast<HloSliceInstruction>(slice_instr)) {
- // Update static slices.
- const Shape& src_shape = static_slice->operand(0)->shape();
- const Shape& dst_shape = static_slice->shape();
- int64_t size = ShapeUtil::ByteSizeOf(dst_shape);
-
- // Given this slice
- // f16[1,4,8]{2,1,0} slice(f16[2,8,8]{2,1,0}),
- // slice={[1:2], [4:8], [0:8]}
- //
- // The offset of the slice should be:
- // slice_starts(0) * 8 * 8 * sizeof(f16) +
- // slice_starts(1) * 8 * sizeof(f16)
- int64_t offset = orig_slice.offset();
- for (auto [start, stride] :
- llvm::zip(static_slice->slice_starts(),
- *ShapeUtil::ByteStrides(src_shape))) {
- offset += start * stride;
- }
-
- return BufferAllocation::Slice(orig_slice.allocation(), offset, size);
- }
-
- return orig_slice;
- }
-
- const auto* param = DynCast<HloParameterInstruction>(start);
- return GetAllocationSlice(buffer_assignment,
- fusion.operand(param->parameter_number()), index);
- };
-
- auto collect_slice_info = [&](unsigned idx) {
- auto* slice_instr =
- DynCastOrNull<HloDynamicIndexInstruction>(slice_instrs[idx]);
- if (slice_instr == nullptr) {
- return;
- }
-
- std::vector<BufferAllocation::Slice> offset_slices;
- for (auto idx_op : slice_instr->index_operands()) {
- const auto* param = Cast<HloParameterInstruction>(idx_op);
- offset_slices.push_back(
- GetAllocationSlice(buffer_assignment,
- fusion.operand(param->parameter_number()),
- /*index=*/{})
- .value());
- }
- offset_buffer_indices[idx] = std::move(offset_slices);
- orig_shapes[idx] = slice_instr->operand(0)->shape();
- sliced_shapes[idx] = DynCast<HloDynamicSliceInstruction>(slice_instr)
- ? slice_instr->shape()
- : slice_instr->operand(1)->shape();
- offset_byte_sizes[idx] = ShapeUtil::ByteSizeOfPrimitiveType(
- slice_instr->index_operands().front()->shape().element_type());
- };
-
- unsigned param_idx = 0;
+ unsigned arg_idx = 0;
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_slice,
- get_original_operand_slice(custom_call.operand(param_idx),
- /*index=*/{}, param_idx));
- collect_slice_info(param_idx++);
+ GetOperandSlice(buffer_assignment, adaptor, fusion,
+ *custom_call.operand(arg_idx),
+ slice_instrs, /*shape_idx=*/{}, arg_idx));
+ TF_RETURN_IF_ERROR(CollectSliceInfo(
+ buffer_assignment, fusion, absl::Span<HloInstruction*>(slice_instrs),
+ offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes,
+ arg_idx++));

TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_slice,
- get_original_operand_slice(custom_call.operand(param_idx),
- /*index=*/{}, param_idx));
- collect_slice_info(param_idx++);
+ GetOperandSlice(buffer_assignment, adaptor, fusion,
+ *custom_call.operand(arg_idx),
+ slice_instrs, /*shape_idx=*/{}, arg_idx));
+ TF_RETURN_IF_ERROR(CollectSliceInfo(
+ buffer_assignment, fusion, absl::Span<HloInstruction*>(slice_instrs),
+ offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes,
+ arg_idx++));

BufferAllocation::Slice output;
std::optional<BufferAllocation::Slice> workspace = std::nullopt;
std::optional<BufferAllocation::Slice> slice_workspace_fake = std::nullopt;

- auto get_original_result_slice =
- [&](const HloInstruction* start, const ShapeIndex& index,
- unsigned param_idx) -> absl::StatusOr<BufferAllocation::Slice> {
- auto slice_adaptor = HloFindIf(
- {HloInstructionAdaptor(*start)}, adaptor,
- [](auto node) {
- return node.opcode() == HloOpcode::kDynamicUpdateSlice;
- },
- false);
- if (slice_adaptor.has_value()) {
- auto* slice_instr =
- const_cast<HloInstruction*>(&slice_adaptor->instruction());
- slice_instrs[param_idx] = slice_instr;
-
- if (!IsContiguousSlice(slice_instr->shape(),
- Cast<HloDynamicUpdateSliceInstruction>(slice_instr)
- ->update()
- ->shape())) {
- return absl::InternalError(
- "DynamicAddressComputationFusion only handles contiguous slices "
- "currently");
- }
- }
-
- return GetAllocationSlice(buffer_assignment, &fusion, index);
- };
-
// Handling cases where multiple operands share the same buffer, with
// different offset by creating new fake allocations so each operand will have
// a different buffer index. The slices can thus always start at offset 0.
// AddressComputationThunk will take care of the offset adjustment.
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations(4);
if (fusion.shape().IsArray()) {
- TF_ASSIGN_OR_RETURN(output, get_original_result_slice(
- &custom_call, /*index=*/{}, param_idx));
- collect_slice_info(param_idx);
+ TF_ASSIGN_OR_RETURN(
+ output, GetResultSlice(buffer_assignment, adaptor, fusion, custom_call,
+ slice_instrs, /*shape_idx=*/{}, arg_idx));
+ TF_RETURN_IF_ERROR(CollectSliceInfo(
+ buffer_assignment, fusion, absl::Span<HloInstruction*>(slice_instrs),
+ offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes,
+ arg_idx));
} else {
TF_ASSIGN_OR_RETURN(
output,
- get_original_result_slice(
- &custom_call, /*index=*/{kGEMMOutputBufferIndex}, param_idx));
- collect_slice_info(param_idx++);
+ GetResultSlice(buffer_assignment, adaptor, fusion, custom_call,
+ slice_instrs, /*shape_idx=*/{kGEMMOutputBufferIndex},
+ arg_idx));
+ TF_RETURN_IF_ERROR(CollectSliceInfo(
+ buffer_assignment, fusion, absl::Span<HloInstruction*>(slice_instrs),
+ offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes,
+ arg_idx++));

// TODO(vuson): If we want to support slices of workspace, we'd need to
// start `HloFindIf` with `get-tuple-element` with the right index.
TF_ASSIGN_OR_RETURN(
workspace, GetAllocationSlice(buffer_assignment, &fusion,
/*index=*/{kGEMMWorkspaceBufferIndex}));
- collect_slice_info(param_idx);
- fake_allocations[param_idx] = std::make_unique<BufferAllocation>(
- /*index=*/param_idx, workspace->size(), /*color=*/0);
+ TF_RETURN_IF_ERROR(CollectSliceInfo(
+ buffer_assignment, fusion, absl::Span<HloInstruction*>(slice_instrs),
+ offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes,
+ arg_idx));
+ fake_allocations[arg_idx] = std::make_unique<BufferAllocation>(
+ /*index=*/arg_idx, workspace->size(), /*color=*/0);
slice_workspace_fake = BufferAllocation::Slice(
- fake_allocations[param_idx].get(), 0, workspace->size());
+ fake_allocations[arg_idx].get(), 0, workspace->size());
}
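
The fake-allocation comment above is the key invariant of this hunk: when several arguments alias one underlying buffer at different offsets, each gets a fresh BufferAllocation of exactly its own size, so every slice handed to the emitter starts at offset 0 and AddressComputationThunk re-applies the real (possibly runtime-computed) offsets. A self-contained toy of that bookkeeping with stand-in types (the real code uses BufferAllocation and BufferAllocation::Slice; the sizes are made up):

    #include <cassert>
    #include <cstdint>
    #include <memory>
    #include <vector>

    // Stand-ins for BufferAllocation and BufferAllocation::Slice.
    struct Allocation { int64_t index; int64_t size; };
    struct Slice { const Allocation* allocation; int64_t offset; int64_t size; };

    int main() {
      // Suppose lhs and rhs are 512-byte slices of one shared buffer at
      // different (possibly dynamic) offsets.
      std::vector<std::unique_ptr<Allocation>> fake_allocations;
      fake_allocations.push_back(std::make_unique<Allocation>(Allocation{0, 512}));
      fake_allocations.push_back(std::make_unique<Allocation>(Allocation{1, 512}));

      // Each argument gets its own buffer index and starts at offset 0; the
      // runtime thunk later rebases these onto the real shared buffer.
      Slice lhs{fake_allocations[0].get(), /*offset=*/0, /*size=*/512};
      Slice rhs{fake_allocations[1].get(), /*offset=*/0, /*size=*/512};
      assert(lhs.allocation->index != rhs.allocation->index);
      assert(lhs.offset == 0 && rhs.offset == 0);
      return 0;
    }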

if (absl::c_all_of(slice_instrs, [&](auto slice_instr) {
@@ -328,30 +362,30 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
nullptr;
})) {
// Creating embedded GEMM thunk.
- unsigned arg_idx = 0;
+ unsigned fake_arg_idx = 0;
int64_t lhs_byte_size =
- ShapeUtil::ByteSizeOf(custom_call.operand(arg_idx)->shape());
- fake_allocations[arg_idx] = std::make_unique<BufferAllocation>(
- /*index=*/arg_idx, lhs_byte_size, /*color=*/0);
- BufferAllocation::Slice slice_lhs_fake(fake_allocations[arg_idx].get(), 0,
- lhs_byte_size);
+ ShapeUtil::ByteSizeOf(custom_call.operand(fake_arg_idx)->shape());
+ fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
+ /*index=*/fake_arg_idx, lhs_byte_size, /*color=*/0);
+ BufferAllocation::Slice slice_lhs_fake(fake_allocations[fake_arg_idx].get(),
+ 0, lhs_byte_size);

- arg_idx++;
+ fake_arg_idx++;
int64_t rhs_byte_size =
- ShapeUtil::ByteSizeOf(custom_call.operand(arg_idx)->shape());
- fake_allocations[arg_idx] = std::make_unique<BufferAllocation>(
- /*index=*/arg_idx, rhs_byte_size, /*color=*/0);
- BufferAllocation::Slice slice_rhs_fake(fake_allocations[arg_idx].get(), 0,
- rhs_byte_size);
+ ShapeUtil::ByteSizeOf(custom_call.operand(fake_arg_idx)->shape());
+ fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
+ /*index=*/fake_arg_idx, rhs_byte_size, /*color=*/0);
+ BufferAllocation::Slice slice_rhs_fake(fake_allocations[fake_arg_idx].get(),
+ 0, rhs_byte_size);

- arg_idx++;
+ fake_arg_idx++;
int64_t out_fake_byte_size = ShapeUtil::ByteSizeOf(
custom_call.shape().IsArray() ? custom_call.shape()
: custom_call.shape().tuple_shapes(0));
- fake_allocations[arg_idx] = std::make_unique<BufferAllocation>(
- /*index=*/arg_idx, out_fake_byte_size, /*color=*/0);
- BufferAllocation::Slice slice_out_fake(fake_allocations[arg_idx].get(), 0,
- out_fake_byte_size);
+ fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
+ /*index=*/fake_arg_idx, out_fake_byte_size, /*color=*/0);
+ BufferAllocation::Slice slice_out_fake(fake_allocations[fake_arg_idx].get(),
+ 0, out_fake_byte_size);
ThunkSequence seq;
seq.emplace_back(std::make_unique<GemmThunk>(
thunk_info, std::move(config), slice_lhs_fake, slice_rhs_fake,
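
On the two ShapeIndex values used when the fusion result is a tuple: the GEMM custom call yields (output, workspace), addressed through kGEMMOutputBufferIndex and kGEMMWorkspaceBufferIndex, which are defined elsewhere in custom.cc and not visible in this diff. A sketch of the branch in EmitGemm that picks the output's ShapeIndex, assuming those constants are 0 and 1:

    #include <cassert>
    #include <vector>

    // Assumed values; the real constants live outside this diff.
    constexpr int kGEMMOutputBufferIndex = 0;
    constexpr int kGEMMWorkspaceBufferIndex = 1;

    // Mirrors EmitGemm: an array-shaped result is addressed by the empty
    // ShapeIndex {}; an (output, workspace) tuple by {0}, with the workspace
    // buffer at {1}.
    static std::vector<int> OutputShapeIndex(bool result_is_array) {
      return result_is_array ? std::vector<int>{}
                             : std::vector<int>{kGEMMOutputBufferIndex};
    }

    int main() {
      assert(OutputShapeIndex(true).empty());
      assert(OutputShapeIndex(false).front() == kGEMMOutputBufferIndex);
      static_assert(kGEMMWorkspaceBufferIndex == 1, "assumed tuple layout");
      return 0;
    }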
