Skip to content

Commit

Permalink
Make max value in Range optional to allow for Unbounded Range calcula…
Browse files Browse the repository at this point in the history
…tions.

Also, cache the intermediate calculated ranges when calling RecrusivelyIdentifyRange.

PiperOrigin-RevId: 705686131
  • Loading branch information
fhoushmand authored and Google-ML-Automation committed Dec 13, 2024
1 parent a803260 commit 71d1c7a
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 130 deletions.
4 changes: 2 additions & 2 deletions xla/hlo/analysis/while_loop_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ bool RangeEqualIgnoreBitwidth(const Range& range, int init, int limit,
: r.min().GetUnsignedValue();
};
auto range_max = [](const Range& r) {
return r.min().IsSigned() ? r.max().GetSignedValue()
: r.max().GetUnsignedValue();
return r.min().IsSigned() ? r.max()->GetSignedValue()
: r.max()->GetUnsignedValue();
};
return range_min(range) == init && range_max(range) == limit &&
range.step().GetSignedValue() == step;
Expand Down
8 changes: 3 additions & 5 deletions xla/service/collective_pipeliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ std::optional<int> GetSlicedDimension(

bool CheckIndexIsMonotonic(
const HloInstruction* index,
const absl::flat_hash_map<const HloInstruction*, Range>& induction_map) {
absl::flat_hash_map<const HloInstruction*, Range>& induction_map) {
// Because the only math operations supported by RecursivelyIdentifyRange()
// are only sub/add then checking that we can compute the range here is enough
// to guarantee that the index is monotonic if the base index is monotonic. If
Expand Down Expand Up @@ -789,8 +789,7 @@ class WhileLoopAnalysis {
CollectivePipeliner::PipeliningDirection direction,
int64_t level_to_operate_on,
const absl::flat_hash_map<int64_t, int64_t>& parameter_gtes_count,
const absl::flat_hash_map<const HloInstruction*, Range>& index_ranges)
const;
absl::flat_hash_map<const HloInstruction*, Range>& index_ranges) const;

// Merges the new collective (instr) with the existing one stored in
// move_infos_[indices_to_merge[0]]. indices_to_merge.size() should be 1.
Expand Down Expand Up @@ -981,8 +980,7 @@ WhileLoopAnalysis::IsSupportedDynamicUpdateSlice(
CollectivePipeliner::PipeliningDirection direction,
int64_t level_to_operate_on,
const absl::flat_hash_map<int64_t, int64_t>& parameter_gtes_count,
const absl::flat_hash_map<const HloInstruction*, Range>& index_ranges)
const {
absl::flat_hash_map<const HloInstruction*, Range>& index_ranges) const {
HloComputation* while_body = while_->while_body();
const HloInstruction* loop_parameter =
while_body->parameter_instructions()[0];
Expand Down
204 changes: 135 additions & 69 deletions xla/service/value_range.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ std::string Range::ToString() const {
return min_.ToString();
}
return absl::StrCat(
"min: ", min_.ToString(), " max: ", max_.ToString(),
"min: ", min_.ToString(),
" max: ", IsBounded() ? max_.value().ToString() : "Unknown",
" step: ", IsStepKnown() ? step_.value().ToString() : "Unknown");
}

Expand All @@ -75,82 +76,107 @@ std::optional<ConstantValue> FindStepForBinaryOp(const Range& lhs,
return std::nullopt;
}

Range RecordAndReturnRange(
const Range& range, const HloInstruction* instr,
absl::flat_hash_map<const HloInstruction*, Range>& known_ranges) {
known_ranges[instr] = range;
VLOG(5) << "Computed range for: " << instr->name() << " -> "
<< range.ToString();
return range;
}

// Identify the value ranges of a scalar HLO with a integer type. It returns
// a range of values that the instruction can have.
Range RecursivelyIdentifyRange(
const HloInstruction* instr,
const absl::flat_hash_map<const HloInstruction*, Range>& predefined_ranges,
absl::flat_hash_map<const HloInstruction*, Range>& known_ranges,
const HloAliasAnalysis* alias_analysis) {
// Non scalar or non-integer HLO. Abort.
if ((!instr->shape().IsInteger() && instr->shape().element_type() != PRED) ||
instr->shape().dimensions_size() != 0) {
return Range{};
}
VLOG(5) << "Computing Range for " << instr->ToString();
auto it = predefined_ranges.find(instr);
if (it != predefined_ranges.end()) {
VLOG(5) << "Found range! " << it->second.max().GetSignedValue() << " "
<< it->second.min().GetSignedValue();
auto it = known_ranges.find(instr);
if (it != known_ranges.end()) {
VLOG(5) << "Found range: " << it->second.ToString();
return it->second;
} else if (alias_analysis != nullptr) {
auto value_set =
alias_analysis->dataflow_analysis().GetFlattenedValueSet(instr);
for (const auto& value : value_set.TakeValues()) {
for (const HloPosition& position : value->positions()) {
auto it = predefined_ranges.find(position.instruction);
if (it != predefined_ranges.end()) {
VLOG(5) << "Found range in defining instruction! "
<< it->second.max().GetSignedValue() << " "
<< it->second.min().GetSignedValue();
auto it = known_ranges.find(position.instruction);
if (it != known_ranges.end()) {
VLOG(5) << "Found range in defining instruction: "
<< it->second.ToString();
return it->second;
}
}
}
}
switch (instr->opcode()) {
case HloOpcode::kGetTupleElement: {
if (alias_analysis != nullptr) {
auto value_set =
alias_analysis->dataflow_analysis().GetFlattenedValueSet(instr);
std::vector<const HloValue*> values = value_set.TakeValues();
if (values.size() != 1) {
VLOG(5) << "Ambiguous value set";
return Range{};
}
HloInstruction* defining_instruction =
values.at(0)->defining_instruction();
if (defining_instruction != nullptr) {
return RecursivelyIdentifyRange(defining_instruction, known_ranges,
alias_analysis);
}
}
return Range{};
}
case HloOpcode::kCompare: {
VLOG(5) << "Handling Compare";
Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges,
Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges,
alias_analysis);
Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges,
Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges,
alias_analysis);
VLOG(5) << "Returned Rhs: " << rhs.ToString()
<< " Lhs: " << lhs.ToString();
// Only kLt supported right now.
if (instr->comparison_direction() != ComparisonDirection::kLt) {
return Range{};
}
if (lhs.max().lt(rhs.min())) {
return Range{ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true};
if (lhs.IsBounded() && lhs.max()->lt(rhs.min())) {
return RecordAndReturnRange(
Range{ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true},
instr, known_ranges);
}
if (!lhs.min().lt(rhs.max())) {
return Range{
ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true};
if (rhs.IsBounded() && !lhs.min().lt(rhs.max().value())) {
return RecordAndReturnRange(
Range{ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true},
instr, known_ranges);
}
VLOG(5) << "Compare failed";
VLOG(5) << "rhs max " << rhs.max().GetSignedValue() << " rhs min "
<< rhs.min().GetSignedValue() << " lhs max "
<< lhs.max().GetSignedValue() << " lhs min "
<< lhs.min().GetSignedValue();
return Range{};
}
case HloOpcode::kConstant: {
if (instr->shape().element_type() == PRED &&
instr->shape().dimensions_size() == 0) {
if (instr->literal().IsAll(true)) {
return Range{
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true};
return RecordAndReturnRange(
Range{ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true},
instr, known_ranges);
}
return Range{
ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true};
return RecordAndReturnRange(
Range{ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true},
instr, known_ranges);
}
if (!instr->shape().IsInteger()) {
return Range{};
Expand All @@ -162,48 +188,60 @@ Range RecursivelyIdentifyRange(
primitive_util::IsSignedIntegralType(instr->shape().element_type());
if (is_signed) {
const int64_t value = *instr->literal().GetFirstInteger();
return Range{ConstantValue::GetSigned(value, bitwidth),
ConstantValue::GetSigned(value, bitwidth),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true};
return RecordAndReturnRange(
Range{ConstantValue::GetSigned(value, bitwidth),
ConstantValue::GetSigned(value, bitwidth),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true},
instr, known_ranges);
}
const uint64_t value = *instr->literal().GetFirstInteger();
return Range{ConstantValue::GetUnsigned(value, bitwidth),
ConstantValue::GetUnsigned(value, bitwidth),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true};
return RecordAndReturnRange(
Range{ConstantValue::GetUnsigned(value, bitwidth),
ConstantValue::GetUnsigned(value, bitwidth),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true},
instr, known_ranges);
}
case HloOpcode::kAdd: {
if (!instr->shape().IsInteger()) {
return Range{};
}
VLOG(5) << "Handling Add";
Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges,
Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges,
alias_analysis);
Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges,
Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges,
alias_analysis);
VLOG(5) << "Returned Rhs: " << rhs.ToString()
<< " Lhs: " << lhs.ToString();
if (lhs.IsEmpty() || rhs.IsEmpty()) {
return Range{};
}
ConstantValue min = lhs.min().add(rhs.min());
ConstantValue max = lhs.max().add(rhs.max());
if (max.lt(min)) {
VLOG(5) << "Add wrapped";
return Range{};
if (lhs.IsBounded() && rhs.IsBounded()) {
ConstantValue max = lhs.max()->add(rhs.max().value());
if (max.lt(min)) {
VLOG(5) << "Add wrapped";
return Range{};
}
return RecordAndReturnRange(
Range{min, max, FindStepForBinaryOp(lhs, rhs).value(),
lhs.IsLinear() && rhs.IsLinear()},
instr, known_ranges);
}
return Range{min, max, FindStepForBinaryOp(lhs, rhs),
lhs.IsLinear() && rhs.IsLinear()};
return RecordAndReturnRange(
Range{min, std::nullopt, FindStepForBinaryOp(lhs, rhs).value(),
lhs.IsLinear() && rhs.IsLinear()},
instr, known_ranges);
}
case HloOpcode::kMultiply: {
if (!instr->shape().IsInteger()) {
return Range{};
}
VLOG(5) << "Handling Multiply";
Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges,
Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges,
alias_analysis);
Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges,
Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges,
alias_analysis);
VLOG(5) << "Returned Rhs: " << rhs.ToString()
<< " Lhs: " << lhs.ToString();
Expand All @@ -215,51 +253,79 @@ Range RecursivelyIdentifyRange(
return Range{};
}
ConstantValue single_value = lhs.IsSingleValue() ? lhs.min() : rhs.min();
ConstantValue min = lhs.IsSingleValue() ? rhs.min().mul(single_value)
: lhs.min().mul(single_value);
ConstantValue max = lhs.IsSingleValue() ? rhs.max().mul(single_value)
: lhs.max().mul(single_value);
return Range{min, max, single_value, lhs.IsLinear() && rhs.IsLinear()};
Range operand_range = lhs.IsSingleValue() ? rhs : lhs;
// When multiplying with a constant, min, max, and step are all
// multiplied by the single value.
ConstantValue min = operand_range.min().mul(single_value);
if (operand_range.IsBounded()) {
ConstantValue max = operand_range.max()->mul(single_value);
if (!operand_range.IsStepKnown()) {
return RecordAndReturnRange(Range{min, max, operand_range.IsLinear()},
instr, known_ranges);
}
ConstantValue step = operand_range.step().mul(single_value);
return RecordAndReturnRange(
Range{min, max, step, operand_range.IsLinear()}, instr,
known_ranges);
}
if (!operand_range.IsStepKnown()) {
return RecordAndReturnRange(
Range{min, std::nullopt, operand_range.IsLinear()}, instr,
known_ranges);
}
ConstantValue step = operand_range.step().mul(single_value);
return RecordAndReturnRange(
Range{min, std::nullopt, step, operand_range.IsLinear()}, instr,
known_ranges);
}
case HloOpcode::kSelect: {
VLOG(5) << "Handling Select: " << instr->ToString();
const HloInstruction* cmp = instr->operand(0);
Range cmp_range =
RecursivelyIdentifyRange(cmp, predefined_ranges, alias_analysis);
RecursivelyIdentifyRange(cmp, known_ranges, alias_analysis);
// Support only when the select has a constant value as condition.
if (cmp_range.IsEmpty() || !cmp_range.IsSingleValue()) {
VLOG(5) << "Select failed";
return Range{};
}
if (cmp_range.GetSingleSignedValue() == 0) {
return RecursivelyIdentifyRange(instr->operand(2), predefined_ranges,
alias_analysis);
return RecordAndReturnRange(
RecursivelyIdentifyRange(instr->operand(2), known_ranges,
alias_analysis),
instr, known_ranges);
}
return RecursivelyIdentifyRange(instr->operand(1), predefined_ranges,
alias_analysis);
return RecordAndReturnRange(
RecursivelyIdentifyRange(instr->operand(1), known_ranges,
alias_analysis),
instr, known_ranges);
}
case HloOpcode::kSubtract: {
if (!instr->shape().IsInteger()) {
return Range{};
}
VLOG(5) << "Handling Subtract";
Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges,
Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges,
alias_analysis);
Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges,
Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges,
alias_analysis);
VLOG(5) << "Returned Rhs: " << rhs.ToString()
<< " Lhs: " << lhs.ToString();
if (lhs.IsEmpty() || rhs.IsEmpty()) {
return Range{};
}
ConstantValue min = lhs.min().sub(rhs.max());
ConstantValue max = lhs.max().sub(rhs.min());
if (!lhs.IsBounded() || !rhs.IsBounded()) {
return Range{};
}
ConstantValue min = lhs.min().sub(rhs.max().value());
ConstantValue max = lhs.max()->sub(rhs.min());
if (max.lt(min)) {
VLOG(5) << "Subtract wrapped";
return Range{};
}
return Range{min, max, FindStepForBinaryOp(lhs, rhs),
lhs.IsLinear() && rhs.IsLinear()};
return RecordAndReturnRange(
Range{min, max, FindStepForBinaryOp(lhs, rhs).value(),
lhs.IsLinear() && rhs.IsLinear()},
instr, known_ranges);
}
default:
break;
Expand Down
Loading

0 comments on commit 71d1c7a

Please sign in to comment.