diff --git a/xla/service/spmd/gather_scatter_handler.cc b/xla/service/spmd/gather_scatter_handler.cc
index ecf06378a266cc..57f13ca7d1c5fb 100644
--- a/xla/service/spmd/gather_scatter_handler.cc
+++ b/xla/service/spmd/gather_scatter_handler.cc
@@ -193,6 +193,44 @@ std::vector GatherOutputDimsByPriority(
   return priority_dims_for_output;
 }
 
+/* Builds a kClamp instruction that bounds every element of `indices` to
+ * [0, size - 1], where `size` is read from `operand_base_shape` for the
+ * operand dimension listed in `start_index_map`, and wraps the result in a
+ * new PartitionedHlo.
+ *
+ * indices: the indices to clamp. Its HLO shape is reused for both bounds and
+ *     for the clamp itself.
+ * operand_base_shape: shape whose dimension sizes define the upper bounds.
+ * start_index_map: the operand dimension addressed by each component of an
+ *     index vector; entry i supplies the bound at position i along
+ *     `index_vector_dim`.
+ * index_vector_dim: dimension of `indices` along which the per-component
+ *     bounds are broadcast. A value >= indices.rank() takes the single-bound
+ *     path, which CHECK-fails unless `start_index_map` has exactly one entry.
+ * b: builder used for b->AddInstruction and handed to CreateR0WithType.
+ *
+ * Returns a PartitionedHlo around the clamp, carrying the sharding, base
+ * shape and state of `indices`.
+ *
+ * NOTE(review): template arguments are not visible in this diff text
+ * (std::vector, absl::Span, CreateR1). Confirm that the rank-1 constant's
+ * element type matches `indices_type`, because it is broadcast directly into
+ * the indices' shape.
+ * NOTE(review): an operand dimension of size 0 yields an upper bound of -1,
+ * which is below the lower bound of 0 -- confirm that case cannot reach here.
+ */ PartitionedHlo ClampGatherIndices(const PartitionedHlo& indices,
+                                  const Shape& operand_base_shape,
+                                  absl::Span start_index_map,
+                                  int64_t index_vector_dim, SpmdBuilder* b) {
+  const PrimitiveType indices_type = indices.hlo()->shape().element_type();
+  // Upper bound, materialized with the same shape as the indices.
+  HloInstruction* max_indices;
+  if (index_vector_dim < indices.rank()) {  // index vector has its own dimension
+    std::vector max_indices_values;
+    max_indices_values.reserve(start_index_map.size());
+    for (int64_t operand_dim : start_index_map) {
+      max_indices_values.push_back(operand_base_shape.dimensions(operand_dim) -
+                                   1);  // largest in-range position in that dim
+    }
+    max_indices = b->AddInstruction(HloInstruction::CreateConstant(
+        LiteralUtil::CreateR1(max_indices_values)));
+    max_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
+        indices.hlo()->shape(), max_indices, {index_vector_dim}));  // bounds laid out along index_vector_dim
+  } else {  // index_vector_dim >= rank: a single bound is applied to all elements
+    CHECK_EQ(start_index_map.size(), 1);
+    max_indices = CreateR0WithType(
+        indices_type, operand_base_shape.dimensions(start_index_map[0]) - 1, b);
+    max_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
+        indices.hlo()->shape(), max_indices, {}));  // scalar broadcast to every element
+  }
+  // Lower bound: a zero of the indices' element type, broadcast to the indices' shape.
+  HloInstruction* constant_zero = CreateR0WithType(indices_type, 0, b);
+  HloInstruction* min_indices =
+      b->AddInstruction(HloInstruction::CreateBroadcast(indices.hlo()->shape(),
+                                                        constant_zero, {}));
+  // Operands are passed in the order (lower bound, value, upper bound).
+  HloInstruction* clamped_indices = b->AddInstruction(
+      HloInstruction::CreateTernary(indices.hlo()->shape(), HloOpcode::kClamp,
+                                    min_indices, indices.hlo(), max_indices));
+  clamped_indices->set_sharding(indices.sharding());  // keep the sharding of the unclamped indices
+  return PartitionedHlo(clamped_indices, indices.base_shape(), indices.state());
+}
+
 // Returns the min and max for the indices in a scatter/gather which has
the // operand partitioned on trivial slice dimensions (slice size 1). std::pair @@ -451,11 +489,9 @@ absl::StatusOr PartitionGatherTrivialSlicedOperandDimensions( SpmdBuilder* b = visitor->builder(); const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers(); - std::vector start_index_map(dnums.start_index_map().begin(), - dnums.start_index_map().end()); if (std::optional> trivial_slice_dims = GatherScatterOperandPartitionedOnTrivialSliceDims( - operand, start_index_map, slice_sizes)) { + operand, dnums.start_index_map(), slice_sizes)) { const HloSharding original_operand_sharding = operand.sharding(); const int64_t num_groups = operand.sharding().NumTiles(*trivial_slice_dims); const int64_t num_tiles = operand.sharding().TotalNumTiles(); @@ -504,6 +540,9 @@ absl::StatusOr PartitionGatherTrivialSlicedOperandDimensions( // Reshard indices to its intended sharding before clamping and adjusting. indices = indices.Reshard(hlo_sharding_util::UngroupSharding(indices_grouped)); + indices = ClampGatherIndices(indices, operand.base_shape(), + dnums.start_index_map(), + dnums.index_vector_dim(), b); // Now the operand is partitioned in trivial slice dimensions, and the // indices are replicated. We execute a gather on partitioned operand, // with full number of indices, where out-of-bounds indices are clamped, @@ -514,8 +553,9 @@ absl::StatusOr PartitionGatherTrivialSlicedOperandDimensions( HloInstruction* indices_max; std::tie(indices_min, indices_max) = IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( - operand, indices, operand.state().partition_id, start_index_map, - *trivial_slice_dims, dnums.index_vector_dim(), b); + operand, indices, operand.state().partition_id, + dnums.start_index_map(), *trivial_slice_dims, + dnums.index_vector_dim(), b); // Clamp the indices. 
auto adjusted_indices = b->AddInstruction( HloInstruction::CreateTernary(indices.hlo()->shape(), HloOpcode::kClamp, diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc index 59b7cce5432c8c..c95573abba52f2 100644 --- a/xla/service/spmd/spmd_partitioner_test.cc +++ b/xla/service/spmd/spmd_partitioner_test.cc @@ -7926,10 +7926,13 @@ ENTRY entry { auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")); auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())), op::Shape("s32[2,3]")); - auto clamp = op::Clamp(min, op::Parameter(1), max); + auto clamped_indices = + op::Clamp(op::Broadcast(op::Constant()), op::Parameter(1), + op::Broadcast(op::Constant())); + auto clamp = op::Clamp(min, clamped_indices, max); auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min)); auto mask = - op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max)); + op::Or(op::Lt(clamped_indices, min), op::Gt(clamped_indices, max)); auto masked = op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather); HloInstruction* root = module->entry_computation()->root_instruction(); @@ -7952,15 +7955,18 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, /*num_devices=*/4)); VLOG(1) << module->ToString(); + auto clamped_indices = + op::Clamp(op::Broadcast(op::Constant()), op::Parameter(1), + op::Broadcast(op::Constant())); auto offset = op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())); auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")); auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())), op::Shape("s32[2,3]")); - auto clamp = op::Clamp(min, op::Parameter(1), max); + auto clamp = op::Clamp(min, clamped_indices, max); auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min)); auto mask = - op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max)); + op::Or(op::Lt(clamped_indices, min), op::Gt(clamped_indices, max)); 
auto masked = op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather); HloInstruction* root = module->entry_computation()->root_instruction(); @@ -11919,11 +11925,10 @@ ENTRY entry { VLOG(1) << module->ToString(); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::AllReduce(op::Select(_, _, op::Gather(_, _)))); - EXPECT_THAT(root->operand(0)->operand(2)->operand(1), - op::Subtract(op::Clamp(_, op::Parameter(1), _), _)); + EXPECT_THAT( + root->operand(0)->operand(2)->operand(1), + op::Subtract(op::Clamp(_, op::Clamp(_, op::Parameter(1), _), _), _)); - auto clamp = FindInstruction(module.get(), HloOpcode::kClamp); - EXPECT_THAT(clamp->operand(1), op::Parameter(1)); auto dynamic_slice = FindInstruction(module.get(), HloOpcode::kDynamicSlice); EXPECT_THAT(dynamic_slice->operand(1), op::PartitionId()); auto collective_permute = @@ -11955,8 +11960,9 @@ ENTRY entry { _, op::AllReduce(op::Select(_, _, op::Gather(op::AllReduce(_), _))), _, _, _))); auto gather = FindInstruction(module.get(), HloOpcode::kGather); - EXPECT_THAT(gather->operand(1), - op::Subtract(op::Clamp(_, op::Parameter(1), _), _)); + EXPECT_THAT( + gather->operand(1), + op::Subtract(op::Clamp(_, op::Clamp(_, op::Parameter(1), _), _), _)); auto collective_permute = FindInstruction(module.get(), HloOpcode::kCollectivePermute); EXPECT_NE(collective_permute, nullptr);