diff --git a/xla/hlo/analysis/while_loop_analysis_test.cc b/xla/hlo/analysis/while_loop_analysis_test.cc index 63af90a28e611..4bf4dbec14342 100644 --- a/xla/hlo/analysis/while_loop_analysis_test.cc +++ b/xla/hlo/analysis/while_loop_analysis_test.cc @@ -301,11 +301,11 @@ bool RangeEqualIgnoreBitwidth(const Range& range, int init, int limit, : r.min().GetUnsignedValue(); }; auto range_max = [](const Range& r) { - return r.min().IsSigned() ? r.max().GetSignedValue() - : r.max().GetUnsignedValue(); + return r.max()->IsSigned() ? r.max()->GetSignedValue() + : r.max()->GetUnsignedValue(); }; return range_min(range) == init && range_max(range) == limit && - range.step().GetSignedValue() == step; + range.step()->GetSignedValue() == step; } TEST_F(WhileLoopAnalysisTest, ExactBoundTrivialRange) { diff --git a/xla/service/collective_pipeliner.cc b/xla/service/collective_pipeliner.cc index d02424990edea..69a4af5295c5f 100644 --- a/xla/service/collective_pipeliner.cc +++ b/xla/service/collective_pipeliner.cc @@ -148,7 +148,7 @@ std::optional GetSlicedDimension( bool CheckIndexIsMonotonic( const HloInstruction* index, - const absl::flat_hash_map& induction_map) { + absl::flat_hash_map& induction_map) { // Because the only math operations supported by RecursivelyIdentifyRange() // are only sub/add then checking that we can compute the range here is enough // to guarantee that the index is monotonic if the base index is monotonic. If @@ -156,7 +156,7 @@ bool CheckIndexIsMonotonic( // sophisticated check for monotonicity. Range range = RecursivelyIdentifyRange(index, induction_map); VLOG(6) << "Range for: " << index->ToString() << " " << range.ToString(); - return !range.IsEmpty() && range.IsLinear(); + return !range.IsEmpty() && range.IsBounded() && range.IsLinear(); } // Check that the parameter is only used in a pattern param -> gte -> @@ -789,8 +789,7 @@ class WhileLoopAnalysis { CollectivePipeliner::PipeliningDirection direction, int64_t level_to_operate_on, const absl::flat_hash_map& parameter_gtes_count, - const absl::flat_hash_map& index_ranges) - const; + absl::flat_hash_map& index_ranges) const; // Merges the new collective (instr) with the existing one stored in // move_infos_[indices_to_merge[0]]. indices_to_merge.size() should be 1. @@ -981,8 +980,7 @@ WhileLoopAnalysis::IsSupportedDynamicUpdateSlice( CollectivePipeliner::PipeliningDirection direction, int64_t level_to_operate_on, const absl::flat_hash_map& parameter_gtes_count, - const absl::flat_hash_map& index_ranges) - const { + absl::flat_hash_map& index_ranges) const { HloComputation* while_body = while_->while_body(); const HloInstruction* loop_parameter = while_body->parameter_instructions()[0]; diff --git a/xla/service/value_range.cc b/xla/service/value_range.cc index 0bdf42ae090b6..d4edd39db8edd 100644 --- a/xla/service/value_range.cc +++ b/xla/service/value_range.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/log/log.h" @@ -54,7 +55,8 @@ std::string Range::ToString() const { return min_.ToString(); } return absl::StrCat( - "min: ", min_.ToString(), " max: ", max_.ToString(), + "min: ", min_.ToString(), + " max: ", IsBounded() ? max_.value().ToString() : "Unknown", " step: ", IsStepKnown() ? step_.value().ToString() : "Unknown"); } @@ -69,17 +71,27 @@ std::optional FindStepForBinaryOp(const Range& lhs, if (rhs.IsSingleValue()) { return lhs.step(); } - if (lhs.step().eq(rhs.step())) { + if (lhs.step()->eq(rhs.step().value())) { return lhs.step(); } return std::nullopt; } +// Helper function that updates the known_ranges map and returns the range. +Range RecordAndReturnRange( + const Range& range, const HloInstruction* instr, + absl::flat_hash_map& known_ranges) { + known_ranges[instr] = range; + VLOG(5) << "Computed range for: " << instr->name() << " -> " + << range.ToString(); + return range; +} + // Identify the value ranges of a scalar HLO with a integer type. It returns // a range of values that the instruction can have. Range RecursivelyIdentifyRange( const HloInstruction* instr, - const absl::flat_hash_map& predefined_ranges, + absl::flat_hash_map& known_ranges, const HloAliasAnalysis* alias_analysis) { // Non scalar or non-integer HLO. Abort. if ((!instr->shape().IsInteger() && instr->shape().element_type() != PRED) || @@ -87,32 +99,48 @@ Range RecursivelyIdentifyRange( return Range{}; } VLOG(5) << "Computing Range for " << instr->ToString(); - auto it = predefined_ranges.find(instr); - if (it != predefined_ranges.end()) { - VLOG(5) << "Found range! " << it->second.max().GetSignedValue() << " " - << it->second.min().GetSignedValue(); + auto it = known_ranges.find(instr); + if (it != known_ranges.end()) { + VLOG(5) << "Found range: " << it->second.ToString(); return it->second; } else if (alias_analysis != nullptr) { auto value_set = alias_analysis->dataflow_analysis().GetFlattenedValueSet(instr); for (const auto& value : value_set.TakeValues()) { for (const HloPosition& position : value->positions()) { - auto it = predefined_ranges.find(position.instruction); - if (it != predefined_ranges.end()) { - VLOG(5) << "Found range in defining instruction! " - << it->second.max().GetSignedValue() << " " - << it->second.min().GetSignedValue(); + auto it = known_ranges.find(position.instruction); + if (it != known_ranges.end()) { + VLOG(5) << "Found range in defining instruction: " + << it->second.ToString(); return it->second; } } } } switch (instr->opcode()) { + case HloOpcode::kGetTupleElement: { + if (alias_analysis != nullptr) { + auto value_set = + alias_analysis->dataflow_analysis().GetFlattenedValueSet(instr); + std::vector values = value_set.TakeValues(); + if (values.size() != 1) { + VLOG(5) << "Ambiguous value set"; + return Range{}; + } + HloInstruction* defining_instruction = + values.at(0)->defining_instruction(); + if (defining_instruction != nullptr) { + return RecursivelyIdentifyRange(defining_instruction, known_ranges, + alias_analysis); + } + } + return Range{}; + } case HloOpcode::kCompare: { VLOG(5) << "Handling Compare"; - Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges, + Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges, alias_analysis); - Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, + Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges, alias_analysis); VLOG(5) << "Returned Rhs: " << rhs.ToString() << " Lhs: " << lhs.ToString(); @@ -120,37 +148,37 @@ Range RecursivelyIdentifyRange( if (instr->comparison_direction() != ComparisonDirection::kLt) { return Range{}; } - if (lhs.max().lt(rhs.min())) { - return Range{ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), - ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), - /*is_linear=*/true}; + if (lhs.IsBounded() && lhs.max()->lt(rhs.min())) { + return RecordAndReturnRange( + Range{ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), + ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), + /*is_linear=*/true}, + instr, known_ranges); } - if (!lhs.min().lt(rhs.max())) { - return Range{ - ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), - ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), - /*is_linear=*/true}; + if (rhs.IsBounded() && !lhs.min().lt(rhs.max().value())) { + return RecordAndReturnRange( + Range{ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), + ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), + /*is_linear=*/true}, + instr, known_ranges); } - VLOG(5) << "Compare failed"; - VLOG(5) << "rhs max " << rhs.max().GetSignedValue() << " rhs min " - << rhs.min().GetSignedValue() << " lhs max " - << lhs.max().GetSignedValue() << " lhs min " - << lhs.min().GetSignedValue(); return Range{}; } case HloOpcode::kConstant: { if (instr->shape().element_type() == PRED && instr->shape().dimensions_size() == 0) { if (instr->literal().IsAll(true)) { - return Range{ - ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), - ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), - /*is_linear=*/true}; + return RecordAndReturnRange( + Range{ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), + ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), + /*is_linear=*/true}, + instr, known_ranges); } - return Range{ - ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), - ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), - /*is_linear=*/true}; + return RecordAndReturnRange( + Range{ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), + ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), + /*is_linear=*/true}, + instr, known_ranges); } if (!instr->shape().IsInteger()) { return Range{}; @@ -162,25 +190,29 @@ Range RecursivelyIdentifyRange( primitive_util::IsSignedIntegralType(instr->shape().element_type()); if (is_signed) { const int64_t value = *instr->literal().GetFirstInteger(); - return Range{ConstantValue::GetSigned(value, bitwidth), - ConstantValue::GetSigned(value, bitwidth), - ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), - /*is_linear=*/true}; + return RecordAndReturnRange( + Range{ConstantValue::GetSigned(value, bitwidth), + ConstantValue::GetSigned(value, bitwidth), + ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), + /*is_linear=*/true}, + instr, known_ranges); } const uint64_t value = *instr->literal().GetFirstInteger(); - return Range{ConstantValue::GetUnsigned(value, bitwidth), - ConstantValue::GetUnsigned(value, bitwidth), - ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), - /*is_linear=*/true}; + return RecordAndReturnRange( + Range{ConstantValue::GetUnsigned(value, bitwidth), + ConstantValue::GetUnsigned(value, bitwidth), + ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), + /*is_linear=*/true}, + instr, known_ranges); } case HloOpcode::kAdd: { if (!instr->shape().IsInteger()) { return Range{}; } VLOG(5) << "Handling Add"; - Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges, + Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges, alias_analysis); - Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, + Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges, alias_analysis); VLOG(5) << "Returned Rhs: " << rhs.ToString() << " Lhs: " << lhs.ToString(); @@ -188,22 +220,29 @@ Range RecursivelyIdentifyRange( return Range{}; } ConstantValue min = lhs.min().add(rhs.min()); - ConstantValue max = lhs.max().add(rhs.max()); - if (max.lt(min)) { - VLOG(5) << "Add wrapped"; - return Range{}; + std::optional step = FindStepForBinaryOp(lhs, rhs); + if (lhs.IsBounded() && rhs.IsBounded()) { + ConstantValue max = lhs.max()->add(rhs.max().value()); + if (max.lt(min)) { + VLOG(5) << "Add wrapped"; + return Range{}; + } + return RecordAndReturnRange( + Range{min, max, step, lhs.IsLinear() && rhs.IsLinear()}, instr, + known_ranges); } - return Range{min, max, FindStepForBinaryOp(lhs, rhs), - lhs.IsLinear() && rhs.IsLinear()}; + return RecordAndReturnRange( + Range{min, std::nullopt, step, lhs.IsLinear() && rhs.IsLinear()}, + instr, known_ranges); } case HloOpcode::kMultiply: { if (!instr->shape().IsInteger()) { return Range{}; } VLOG(5) << "Handling Multiply"; - Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges, + Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges, alias_analysis); - Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, + Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges, alias_analysis); VLOG(5) << "Returned Rhs: " << rhs.ToString() << " Lhs: " << lhs.ToString(); @@ -219,52 +258,84 @@ Range RecursivelyIdentifyRange( // When multiplying with a constant, min, max, and step are all // multiplied by the single value. ConstantValue min = operand_range.min().mul(single_value); - ConstantValue max = operand_range.max().mul(single_value); + if (operand_range.IsBounded()) { + ConstantValue max = operand_range.max()->mul(single_value); + if (!operand_range.IsStepKnown()) { + return RecordAndReturnRange(Range{min, max, operand_range.IsLinear()}, + instr, known_ranges); + } + ConstantValue step = operand_range.step()->mul(single_value); + return RecordAndReturnRange( + Range{min, max, step, operand_range.IsLinear()}, instr, + known_ranges); + } if (!operand_range.IsStepKnown()) { - return Range{min, max, operand_range.IsLinear()}; + return RecordAndReturnRange( + Range{min, std::nullopt, operand_range.IsLinear()}, instr, + known_ranges); } - ConstantValue step = operand_range.step().mul(single_value); - return Range{min, max, step, operand_range.IsLinear()}; + ConstantValue step = operand_range.step()->mul(single_value); + return RecordAndReturnRange( + Range{min, std::nullopt, step, operand_range.IsLinear()}, instr, + known_ranges); } case HloOpcode::kSelect: { VLOG(5) << "Handling Select: " << instr->ToString(); const HloInstruction* cmp = instr->operand(0); Range cmp_range = - RecursivelyIdentifyRange(cmp, predefined_ranges, alias_analysis); + RecursivelyIdentifyRange(cmp, known_ranges, alias_analysis); // Support only when the select has a constant value as condition. if (cmp_range.IsEmpty() || !cmp_range.IsSingleValue()) { VLOG(5) << "Select failed"; return Range{}; } if (cmp_range.GetSingleSignedValue() == 0) { - return RecursivelyIdentifyRange(instr->operand(2), predefined_ranges, - alias_analysis); + return RecordAndReturnRange( + RecursivelyIdentifyRange(instr->operand(2), known_ranges, + alias_analysis), + instr, known_ranges); } - return RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, - alias_analysis); + return RecordAndReturnRange( + RecursivelyIdentifyRange(instr->operand(1), known_ranges, + alias_analysis), + instr, known_ranges); } case HloOpcode::kSubtract: { if (!instr->shape().IsInteger()) { return Range{}; } VLOG(5) << "Handling Subtract"; - Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges, + Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges, alias_analysis); - Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, + Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges, alias_analysis); VLOG(5) << "Returned Rhs: " << rhs.ToString() << " Lhs: " << lhs.ToString(); if (lhs.IsEmpty() || rhs.IsEmpty()) { return Range{}; } - ConstantValue min = lhs.min().sub(rhs.max()); - ConstantValue max = lhs.max().sub(rhs.min()); - if (max.lt(min)) { - VLOG(5) << "Subtract wrapped"; + if (lhs.IsBounded() && rhs.IsBounded()) { + ConstantValue min = lhs.min().sub(rhs.max().value()); + ConstantValue max = lhs.max()->sub(rhs.min()); + if (max.lt(min)) { + VLOG(5) << "Subtract wrapped"; + return Range{}; + } + return RecordAndReturnRange( + Range{min, max, FindStepForBinaryOp(lhs, rhs), + lhs.IsLinear() && rhs.IsLinear()}, + instr, known_ranges); + } else if (lhs.IsBounded()) { // bounded - unbounded -> Empty range + VLOG(5) << "Subtract unbounded from bounded is not represntable with a " + "range"; return Range{}; + } else { // unbounded - bounded -> Unbounded range + ConstantValue min = lhs.min().sub(rhs.max().value()); + return RecordAndReturnRange( + Range{min, std::nullopt, FindStepForBinaryOp(lhs, rhs), + lhs.IsLinear() && rhs.IsLinear()}, + instr, known_ranges); } - return Range{min, max, FindStepForBinaryOp(lhs, rhs), - lhs.IsLinear() && rhs.IsLinear()}; } default: break; diff --git a/xla/service/value_range.h b/xla/service/value_range.h index b46b9bbcfa22f..eb06d3b488ffd 100644 --- a/xla/service/value_range.h +++ b/xla/service/value_range.h @@ -26,7 +26,10 @@ limitations under the License. namespace xla { -// Class keeping track of the range of an HLO value. +// Class keeping track of the range of an HLO value. A range is typically +// defined by a minimum value, a maximum value, and a step value. The step and +// maximum values are optional. If the maximum value is missing, the range is +// unbounded. The default step value is nullopt. class Range { public: Range() @@ -35,13 +38,14 @@ class Range { step_(ConstantValue::GetZero(/*bitwidth=*/64, /*is_signed=*/false)), empty_(true), is_linear_(false) {} - Range(const ConstantValue& min, const ConstantValue& max, bool is_linear) + Range(const ConstantValue& min, std::optional max, + bool is_linear) : min_(min), max_(max), step_(std::nullopt), empty_(false), is_linear_(is_linear) {} - Range(const ConstantValue& min, const ConstantValue& max, + Range(const ConstantValue& min, std::optional max, std::optional step, bool is_linear) : min_(min), max_(max), @@ -51,13 +55,15 @@ class Range { // Minimum value of the range. const ConstantValue& min() const { return min_; } // Maximum value of the range. - const ConstantValue& max() const { return max_; } + const std::optional& max() const { return max_; } // Step value of the range. - const ConstantValue& step() const { return step_.value(); } - // Returns if the range is empty (no value in set). + const std::optional& step() const { return step_; } + // Returns if the range has min and max values (it can be a single value). bool IsEmpty() const { return empty_; } // Only one value in set. This means the range is a constant. - bool IsSingleValue() const { return !IsEmpty() && min_ == max_; } + bool IsSingleValue() const { + return !IsEmpty() && max_.has_value() && min_ == max_; + } // This is a way to track in some way recurring values that change in a // monotonic way. This true means that the variables driving the range change // in a monotonic way and that the way they are composed together is linear @@ -65,6 +71,8 @@ class Range { // loop recursion. bool IsLinear() const { return is_linear_; } bool IsStepKnown() const { return step_.has_value(); } + // If this range is a bounded range with known max value. + bool IsBounded() const { return max_.has_value(); } // If this range represents a single value return that signed value. std::optional GetSingleSignedValue() const; // If this range represents a single value return that unsigned value. @@ -81,20 +89,20 @@ class Range { private: ConstantValue min_; - ConstantValue max_; + std::optional max_; std::optional step_; bool empty_; bool is_linear_; }; -// Constructs a Range object from a HloInstruction. Gets a "predefined_ranges" +// Constructs a Range object from a HloInstruction. Gets a "known_ranges" // object as input that returns known ranges for some variables for which we // already know the range. The final range is composed from operations over // these predetermined ranges. // The input HLO needs to be of scalar type and integer. Range RecursivelyIdentifyRange( const HloInstruction* instr, - const absl::flat_hash_map& predefined_ranges, + absl::flat_hash_map& known_ranges, const HloAliasAnalysis* alias_analysis = nullptr); } // namespace xla diff --git a/xla/service/value_range_test.cc b/xla/service/value_range_test.cc index 0b83a374e5da0..ff389b92b11c5 100644 --- a/xla/service/value_range_test.cc +++ b/xla/service/value_range_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/value_range.h" +#include #include #include @@ -22,6 +23,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/service/constant_value.h" @@ -59,8 +61,8 @@ TEST_F(ValueRangeTest, AddedValue) { EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); EXPECT_EQ(range.min().GetSignedValue(), 124); - EXPECT_EQ(range.max().GetSignedValue(), 124 + 5); - EXPECT_EQ(range.step().GetSignedValue(), 1); + EXPECT_EQ(range.max()->GetSignedValue(), 124 + 5); + EXPECT_EQ(range.step()->GetSignedValue(), 1); } TEST_F(ValueRangeTest, MultiplyValue) { @@ -89,8 +91,53 @@ TEST_F(ValueRangeTest, MultiplyValue) { EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); EXPECT_EQ(range.min().GetSignedValue(), 0); - EXPECT_EQ(range.max().GetSignedValue(), 32 * 1024); - EXPECT_EQ(range.step().GetSignedValue(), 2 * 1024); + EXPECT_EQ(range.max()->GetSignedValue(), 32 * 1024); + EXPECT_EQ(range.step()->GetSignedValue(), 2 * 1024); +} + +TEST_F(ValueRangeTest, MultiplyValuePassedToLoop) { + constexpr absl::string_view hlo_string = R"( + HloModule module + body.comp { + p0 = (s32[], s32[]) parameter(0) + gte = s32[] get-tuple-element(p0), index=0 + ROOT tuple = (s32[], s32[]) tuple(gte, gte) + } + cond.comp { + p0 = (s32[], s32[]) parameter(0) + ROOT out = pred[] constant(true) + } + ENTRY entry { + c0 = s32[] constant(1024) + p0 = s32[] parameter(0) + %mul = s32[] multiply(p0, c0) + tuple = (s32[], s32[]) tuple(%mul, %mul) + ROOT out = (s32[], s32[]) while(tuple), condition=cond.comp, + body=body.comp + } + )"; + auto module = + ParseAndReturnUnverifiedModule(hlo_string, HloModuleConfig{}).value(); + TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis, + HloAliasAnalysis::Run(module.get())); + const HloInstruction* p0 = + module->entry_computation()->parameter_instruction(0); + absl::flat_hash_map fs; + // p0 has range min = 0, max = 32, step = 2. + fs.insert(std::make_pair( + p0, Range{/*min=*/ConstantValue::GetSigned(0, /*bitwidth=*/32), + /*max=*/ConstantValue::GetSigned(32, /*bitwidth=*/32), + /*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32), + /*is_linear=*/true})); + HloComputation* body = module->GetComputationWithName("body.comp"); + HloInstruction* gte = body->GetInstructionWithName("gte"); + auto range = RecursivelyIdentifyRange(gte, fs, alias_analysis.get()); + EXPECT_FALSE(range.IsEmpty()); + EXPECT_FALSE(range.IsSingleValue()); + EXPECT_TRUE(range.IsLinear()); + EXPECT_EQ(range.min().GetSignedValue(), 0); + EXPECT_EQ(range.max()->GetSignedValue(), 32 * 1024); + EXPECT_EQ(range.step()->GetSignedValue(), 2 * 1024); } TEST_F(ValueRangeTest, ConstantValuePred) { @@ -105,14 +152,15 @@ TEST_F(ValueRangeTest, ConstantValuePred) { auto module = ParseAndReturnUnverifiedModule(hlo_string, HloModuleConfig{}).value(); const HloInstruction* tuple = module->entry_computation()->root_instruction(); - auto false_range = RecursivelyIdentifyRange(tuple->operand(0), {}); + absl::flat_hash_map known_ranges; + auto false_range = RecursivelyIdentifyRange(tuple->operand(0), known_ranges); VLOG(3) << "false_range: " << false_range.ToString(); EXPECT_FALSE(false_range.IsEmpty()); EXPECT_TRUE(false_range.IsSingleValue()); EXPECT_TRUE(false_range.IsLinear()); EXPECT_EQ(false_range.min().GetUnsignedValue(), 0); - auto true_range = RecursivelyIdentifyRange(tuple->operand(1), {}); + auto true_range = RecursivelyIdentifyRange(tuple->operand(1), known_ranges); VLOG(3) << "true_range: " << true_range.ToString(); EXPECT_FALSE(true_range.IsEmpty()); EXPECT_TRUE(true_range.IsSingleValue()); @@ -138,7 +186,8 @@ TEST_F(ValueRangeTest, ConstantValueWithConditional) { ENTRY entry { p0 = s32[] parameter(0) branch_index = s32[] parameter(1) - ROOT conditional.1 = (s32[], s32[]) conditional(branch_index, p0, p0), branch_computations={region1, region2} + ROOT conditional.1 = (s32[], s32[]) conditional(branch_index, p0, p0), + branch_computations={region1, region2} } )"; auto module = @@ -164,16 +213,16 @@ TEST_F(ValueRangeTest, ConstantValueWithConditional) { EXPECT_FALSE(add_range.IsSingleValue()); EXPECT_TRUE(add_range.IsLinear()); EXPECT_EQ(add_range.min().GetSignedValue(), 1024); - EXPECT_EQ(add_range.max().GetSignedValue(), 1024 + 32); - EXPECT_EQ(add_range.step().GetSignedValue(), 2); + EXPECT_EQ(add_range.max()->GetSignedValue(), 1024 + 32); + EXPECT_EQ(add_range.step()->GetSignedValue(), 2); auto mult_range = RecursivelyIdentifyRange(mult, fs, alias_analysis.get()); EXPECT_FALSE(mult_range.IsEmpty()); EXPECT_FALSE(mult_range.IsSingleValue()); EXPECT_TRUE(mult_range.IsLinear()); EXPECT_EQ(mult_range.min().GetSignedValue(), 0); - EXPECT_EQ(mult_range.max().GetSignedValue(), 32 * 1024); - EXPECT_EQ(mult_range.step().GetSignedValue(), 2 * 1024); + EXPECT_EQ(mult_range.max()->GetSignedValue(), 32 * 1024); + EXPECT_EQ(mult_range.step()->GetSignedValue(), 2 * 1024); } TEST_F(ValueRangeTest, SelectValueWithCompareInConditional) { @@ -183,28 +232,29 @@ TEST_F(ValueRangeTest, SelectValueWithCompareInConditional) { region1_param = s32[] parameter(0) region1_c0 = s32[] constant(1024) %add = s32[] add(region1_param, region1_c0) - - compare_const = s32[] constant(1030) // this valueis bigger than the max of add + + compare_const = s32[] constant(1030) compare1 = pred[] compare(%add, compare_const), direction=LT select1 = s32[] select(compare1, region1_param, %add) - + ROOT out = (s32[], s32[]) tuple(%add, %add) } region2 { region2_param = s32[] parameter(0) region2_c0 = s32[] constant(1024) %mult = s32[] multiply(region2_param, region2_c0) - - compare_const = s32[] constant(5121) // this valueis bigger than the max of mult + + compare_const = s32[] constant(5121) compare2 = pred[] compare(%mult, compare_const), direction=LT select2 = s32[] select(compare2, region2_param, %mult) - + ROOT out = (s32[], s32[]) tuple(%mult, %mult) } ENTRY entry { p0 = s32[] parameter(0) branch_index = s32[] parameter(1) - ROOT conditional.1 = (s32[], s32[]) conditional(branch_index, p0, p0), branch_computations={region1, region2} + ROOT conditional.1 = (s32[], s32[]) conditional(branch_index, p0, p0), + branch_computations={region1, region2} } )"; auto module = @@ -257,7 +307,7 @@ ENTRY entry { EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); EXPECT_EQ(range.min().GetUnsignedValue(), 32768); - EXPECT_EQ(range.max().GetUnsignedValue(), 32773); + EXPECT_EQ(range.max()->GetUnsignedValue(), 32773); } TEST_F(ValueRangeTest, SubtractValue) { @@ -283,7 +333,7 @@ ENTRY entry { EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); EXPECT_EQ(range.min().GetSignedValue(), -124); - EXPECT_EQ(range.max().GetSignedValue(), -119); + EXPECT_EQ(range.max()->GetSignedValue(), -119); } TEST_F(ValueRangeTest, SelectValue) { @@ -311,7 +361,7 @@ ENTRY entry { EXPECT_FALSE(range.IsEmpty()); EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); - EXPECT_EQ(range.max().GetSignedValue(), -119); + EXPECT_EQ(range.max()->GetSignedValue(), -119); EXPECT_EQ(range.min().GetSignedValue(), -124); } @@ -340,10 +390,47 @@ ENTRY entry { EXPECT_FALSE(range.IsEmpty()); EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); - EXPECT_EQ(range.max().GetSignedValue(), 129); + EXPECT_EQ(range.max()->GetSignedValue(), 129); EXPECT_EQ(range.min().GetSignedValue(), 124); } +TEST_F(ValueRangeTest, SelectBoundedFromUnboundedRange) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + p0 = s32[] parameter(0) + p1 = s32[] parameter(1) + ROOT %s = s32[] subtract(p0, p1) +} +)"; + auto module = + ParseAndReturnUnverifiedModule(hlo_string, HloModuleConfig{}).value(); + const HloInstruction* root = module->entry_computation()->root_instruction(); + const HloInstruction* p0 = + module->entry_computation()->parameter_instruction(0); + const HloInstruction* p1 = + module->entry_computation()->parameter_instruction(1); + absl::flat_hash_map fs; + // p0 has range min = 1, max = Unknown, step = 2 + fs.insert(std::make_pair( + p0, Range{/*min=*/ConstantValue::GetSigned(1, 32), + /*max=*/std::nullopt, + /*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32), + /*is_linear=*/true})); + // p1 has range min = 0, max = 10, step = 2 + fs.insert(std::make_pair( + p1, Range{/*min=*/ConstantValue::GetZero(32, /*is_signed=*/true), + /*max=*/ConstantValue::GetSigned(10, 32), + /*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32), + /*is_linear=*/true})); + auto range = RecursivelyIdentifyRange(root, fs); + EXPECT_FALSE(range.IsSingleValue()); + EXPECT_TRUE(range.IsLinear()); + EXPECT_FALSE(range.IsBounded()); + EXPECT_EQ(range.min().GetSignedValue(), 1 - 10); +} + TEST_F(ValueRangeTest, AddSubtractValue) { constexpr absl::string_view hlo_string = R"( HloModule module @@ -371,7 +458,7 @@ ENTRY entry { EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); EXPECT_EQ(range.min().GetSignedValue(), 112); - EXPECT_EQ(range.max().GetSignedValue(), 117); + EXPECT_EQ(range.max()->GetSignedValue(), 117); } TEST_F(ValueRangeTest, SubtractWrapAroundValue) { @@ -389,10 +476,10 @@ ENTRY entry { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* p0 = root->operand(0); absl::flat_hash_map fs; - fs.insert( - std::make_pair(p0, Range{ConstantValue::GetSigned(-32768, 16), - ConstantValue::GetZero(16, /*is_signed=*/true), - /*is_linear=*/true})); + fs.insert(std::make_pair(p0, Range{ConstantValue::GetSigned(-32768, 16), + ConstantValue::GetZero(16, + /*is_signed=*/true), + /*is_linear=*/true})); auto range = RecursivelyIdentifyRange(root, fs); EXPECT_TRUE(range.IsEmpty()); EXPECT_FALSE(range.IsSingleValue()); @@ -414,10 +501,10 @@ ENTRY entry { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* p0 = root->operand(0); absl::flat_hash_map fs; - fs.insert( - std::make_pair(p0, Range{ConstantValue::GetZero(16, /*is_signed=*/true), - ConstantValue::GetSigned(32760, 16), - /*is_linear=*/true})); + fs.insert(std::make_pair(p0, Range{ConstantValue::GetZero(16, + /*is_signed=*/true), + ConstantValue::GetSigned(32760, 16), + /*is_linear=*/true})); auto range = RecursivelyIdentifyRange(root, fs); EXPECT_TRUE(range.IsEmpty()); EXPECT_FALSE(range.IsSingleValue());