Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make max value in Range optional to allow for Unbounded Range calculations. #20495

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions xla/hlo/analysis/while_loop_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ bool RangeEqualIgnoreBitwidth(const Range& range, int init, int limit,
: r.min().GetUnsignedValue();
};
auto range_max = [](const Range& r) {
return r.min().IsSigned() ? r.max().GetSignedValue()
: r.max().GetUnsignedValue();
return r.min().IsSigned() ? r.max()->GetSignedValue()
: r.max()->GetUnsignedValue();
};
return range_min(range) == init && range_max(range) == limit &&
range.step().GetSignedValue() == step;
Expand Down
8 changes: 3 additions & 5 deletions xla/service/collective_pipeliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ std::optional<int> GetSlicedDimension(

bool CheckIndexIsMonotonic(
const HloInstruction* index,
const absl::flat_hash_map<const HloInstruction*, Range>& induction_map) {
absl::flat_hash_map<const HloInstruction*, Range>& induction_map) {
// Because the only math operations supported by RecursivelyIdentifyRange()
// are only sub/add then checking that we can compute the range here is enough
// to guarantee that the index is monotonic if the base index is monotonic. If
Expand Down Expand Up @@ -789,8 +789,7 @@ class WhileLoopAnalysis {
CollectivePipeliner::PipeliningDirection direction,
int64_t level_to_operate_on,
const absl::flat_hash_map<int64_t, int64_t>& parameter_gtes_count,
const absl::flat_hash_map<const HloInstruction*, Range>& index_ranges)
const;
absl::flat_hash_map<const HloInstruction*, Range>& index_ranges) const;

// Merges the new collective (instr) with the existing one stored in
// move_infos_[indices_to_merge[0]]. indices_to_merge.size() should be 1.
Expand Down Expand Up @@ -981,8 +980,7 @@ WhileLoopAnalysis::IsSupportedDynamicUpdateSlice(
CollectivePipeliner::PipeliningDirection direction,
int64_t level_to_operate_on,
const absl::flat_hash_map<int64_t, int64_t>& parameter_gtes_count,
const absl::flat_hash_map<const HloInstruction*, Range>& index_ranges)
const {
absl::flat_hash_map<const HloInstruction*, Range>& index_ranges) const {
HloComputation* while_body = while_->while_body();
const HloInstruction* loop_parameter =
while_body->parameter_instructions()[0];
Expand Down
198 changes: 129 additions & 69 deletions xla/service/value_range.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/log/log.h"
Expand Down Expand Up @@ -54,7 +55,8 @@ std::string Range::ToString() const {
return min_.ToString();
}
return absl::StrCat(
"min: ", min_.ToString(), " max: ", max_.ToString(),
"min: ", min_.ToString(),
" max: ", IsBounded() ? max_.value().ToString() : "Unknown",
" step: ", IsStepKnown() ? step_.value().ToString() : "Unknown");
}

Expand All @@ -69,88 +71,113 @@ std::optional<ConstantValue> FindStepForBinaryOp(const Range& lhs,
if (rhs.IsSingleValue()) {
return lhs.step();
}
if (lhs.step().eq(rhs.step())) {
if (lhs.step()->eq(rhs.step().value())) {
return lhs.step();
}
return std::nullopt;
}

Range RecordAndReturnRange(
const Range& range, const HloInstruction* instr,
absl::flat_hash_map<const HloInstruction*, Range>& known_ranges) {
known_ranges[instr] = range;
VLOG(5) << "Computed range for: " << instr->name() << " -> "
<< range.ToString();
return range;
}

// Identify the value ranges of a scalar HLO with a integer type. It returns
// a range of values that the instruction can have.
Range RecursivelyIdentifyRange(
const HloInstruction* instr,
const absl::flat_hash_map<const HloInstruction*, Range>& predefined_ranges,
absl::flat_hash_map<const HloInstruction*, Range>& known_ranges,
const HloAliasAnalysis* alias_analysis) {
// Non scalar or non-integer HLO. Abort.
if ((!instr->shape().IsInteger() && instr->shape().element_type() != PRED) ||
instr->shape().dimensions_size() != 0) {
return Range{};
}
VLOG(5) << "Computing Range for " << instr->ToString();
auto it = predefined_ranges.find(instr);
if (it != predefined_ranges.end()) {
VLOG(5) << "Found range! " << it->second.max().GetSignedValue() << " "
<< it->second.min().GetSignedValue();
auto it = known_ranges.find(instr);
if (it != known_ranges.end()) {
VLOG(5) << "Found range: " << it->second.ToString();
return it->second;
} else if (alias_analysis != nullptr) {
auto value_set =
alias_analysis->dataflow_analysis().GetFlattenedValueSet(instr);
for (const auto& value : value_set.TakeValues()) {
for (const HloPosition& position : value->positions()) {
auto it = predefined_ranges.find(position.instruction);
if (it != predefined_ranges.end()) {
VLOG(5) << "Found range in defining instruction! "
<< it->second.max().GetSignedValue() << " "
<< it->second.min().GetSignedValue();
auto it = known_ranges.find(position.instruction);
if (it != known_ranges.end()) {
VLOG(5) << "Found range in defining instruction: "
<< it->second.ToString();
return it->second;
}
}
}
}
switch (instr->opcode()) {
case HloOpcode::kGetTupleElement: {
if (alias_analysis != nullptr) {
auto value_set =
alias_analysis->dataflow_analysis().GetFlattenedValueSet(instr);
std::vector<const HloValue*> values = value_set.TakeValues();
if (values.size() != 1) {
VLOG(5) << "Ambiguous value set";
return Range{};
}
HloInstruction* defining_instruction =
values.at(0)->defining_instruction();
if (defining_instruction != nullptr) {
return RecursivelyIdentifyRange(defining_instruction, known_ranges,
alias_analysis);
}
}
return Range{};
}
case HloOpcode::kCompare: {
VLOG(5) << "Handling Compare";
Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges,
Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges,
alias_analysis);
Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges,
Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges,
alias_analysis);
VLOG(5) << "Returned Rhs: " << rhs.ToString()
<< " Lhs: " << lhs.ToString();
// Only kLt supported right now.
if (instr->comparison_direction() != ComparisonDirection::kLt) {
return Range{};
}
if (lhs.max().lt(rhs.min())) {
return Range{ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true};
if (lhs.IsBounded() && lhs.max()->lt(rhs.min())) {
return RecordAndReturnRange(
Range{ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true},
instr, known_ranges);
}
if (!lhs.min().lt(rhs.max())) {
return Range{
ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true};
if (rhs.IsBounded() && !lhs.min().lt(rhs.max().value())) {
return RecordAndReturnRange(
Range{ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true},
instr, known_ranges);
}
VLOG(5) << "Compare failed";
VLOG(5) << "rhs max " << rhs.max().GetSignedValue() << " rhs min "
<< rhs.min().GetSignedValue() << " lhs max "
<< lhs.max().GetSignedValue() << " lhs min "
<< lhs.min().GetSignedValue();
return Range{};
}
case HloOpcode::kConstant: {
if (instr->shape().element_type() == PRED &&
instr->shape().dimensions_size() == 0) {
if (instr->literal().IsAll(true)) {
return Range{
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true};
return RecordAndReturnRange(
Range{ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true},
instr, known_ranges);
}
return Range{
ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true};
return RecordAndReturnRange(
Range{ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true},
instr, known_ranges);
}
if (!instr->shape().IsInteger()) {
return Range{};
Expand All @@ -162,48 +189,59 @@ Range RecursivelyIdentifyRange(
primitive_util::IsSignedIntegralType(instr->shape().element_type());
if (is_signed) {
const int64_t value = *instr->literal().GetFirstInteger();
return Range{ConstantValue::GetSigned(value, bitwidth),
ConstantValue::GetSigned(value, bitwidth),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true};
return RecordAndReturnRange(
Range{ConstantValue::GetSigned(value, bitwidth),
ConstantValue::GetSigned(value, bitwidth),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true},
instr, known_ranges);
}
const uint64_t value = *instr->literal().GetFirstInteger();
return Range{ConstantValue::GetUnsigned(value, bitwidth),
ConstantValue::GetUnsigned(value, bitwidth),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true};
return RecordAndReturnRange(
Range{ConstantValue::GetUnsigned(value, bitwidth),
ConstantValue::GetUnsigned(value, bitwidth),
ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false),
/*is_linear=*/true},
instr, known_ranges);
}
case HloOpcode::kAdd: {
if (!instr->shape().IsInteger()) {
return Range{};
}
VLOG(5) << "Handling Add";
Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges,
Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges,
alias_analysis);
Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges,
Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges,
alias_analysis);
VLOG(5) << "Returned Rhs: " << rhs.ToString()
<< " Lhs: " << lhs.ToString();
if (lhs.IsEmpty() || rhs.IsEmpty()) {
return Range{};
}
ConstantValue min = lhs.min().add(rhs.min());
ConstantValue max = lhs.max().add(rhs.max());
if (max.lt(min)) {
VLOG(5) << "Add wrapped";
return Range{};
std::optional<ConstantValue> step = FindStepForBinaryOp(lhs, rhs);
if (lhs.IsBounded() && rhs.IsBounded()) {
ConstantValue max = lhs.max()->add(rhs.max().value());
if (max.lt(min)) {
VLOG(5) << "Add wrapped";
return Range{};
}
return RecordAndReturnRange(
Range{min, max, step, lhs.IsLinear() && rhs.IsLinear()}, instr,
known_ranges);
}
return Range{min, max, FindStepForBinaryOp(lhs, rhs),
lhs.IsLinear() && rhs.IsLinear()};
return RecordAndReturnRange(
Range{min, std::nullopt, step, lhs.IsLinear() && rhs.IsLinear()},
instr, known_ranges);
}
case HloOpcode::kMultiply: {
if (!instr->shape().IsInteger()) {
return Range{};
}
VLOG(5) << "Handling Multiply";
Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges,
Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges,
alias_analysis);
Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges,
Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges,
alias_analysis);
VLOG(5) << "Returned Rhs: " << rhs.ToString()
<< " Lhs: " << lhs.ToString();
Expand All @@ -219,52 +257,74 @@ Range RecursivelyIdentifyRange(
// When multiplying with a constant, min, max, and step are all
// multiplied by the single value.
ConstantValue min = operand_range.min().mul(single_value);
ConstantValue max = operand_range.max().mul(single_value);
if (operand_range.IsBounded()) {
ConstantValue max = operand_range.max()->mul(single_value);
if (!operand_range.IsStepKnown()) {
return RecordAndReturnRange(Range{min, max, operand_range.IsLinear()},
instr, known_ranges);
}
ConstantValue step = operand_range.step()->mul(single_value);
return RecordAndReturnRange(
Range{min, max, step, operand_range.IsLinear()}, instr,
known_ranges);
}
if (!operand_range.IsStepKnown()) {
return Range{min, max, operand_range.IsLinear()};
return RecordAndReturnRange(
Range{min, std::nullopt, operand_range.IsLinear()}, instr,
known_ranges);
}
ConstantValue step = operand_range.step().mul(single_value);
return Range{min, max, step, operand_range.IsLinear()};
ConstantValue step = operand_range.step()->mul(single_value);
return RecordAndReturnRange(
Range{min, std::nullopt, step, operand_range.IsLinear()}, instr,
known_ranges);
}
case HloOpcode::kSelect: {
VLOG(5) << "Handling Select: " << instr->ToString();
const HloInstruction* cmp = instr->operand(0);
Range cmp_range =
RecursivelyIdentifyRange(cmp, predefined_ranges, alias_analysis);
RecursivelyIdentifyRange(cmp, known_ranges, alias_analysis);
// Support only when the select has a constant value as condition.
if (cmp_range.IsEmpty() || !cmp_range.IsSingleValue()) {
VLOG(5) << "Select failed";
return Range{};
}
if (cmp_range.GetSingleSignedValue() == 0) {
return RecursivelyIdentifyRange(instr->operand(2), predefined_ranges,
alias_analysis);
return RecordAndReturnRange(
RecursivelyIdentifyRange(instr->operand(2), known_ranges,
alias_analysis),
instr, known_ranges);
}
return RecursivelyIdentifyRange(instr->operand(1), predefined_ranges,
alias_analysis);
return RecordAndReturnRange(
RecursivelyIdentifyRange(instr->operand(1), known_ranges,
alias_analysis),
instr, known_ranges);
}
case HloOpcode::kSubtract: {
if (!instr->shape().IsInteger()) {
return Range{};
}
VLOG(5) << "Handling Subtract";
Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges,
Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges,
alias_analysis);
Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges,
Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges,
alias_analysis);
VLOG(5) << "Returned Rhs: " << rhs.ToString()
<< " Lhs: " << lhs.ToString();
if (lhs.IsEmpty() || rhs.IsEmpty()) {
return Range{};
}
ConstantValue min = lhs.min().sub(rhs.max());
ConstantValue max = lhs.max().sub(rhs.min());
if (!lhs.IsBounded() || !rhs.IsBounded()) {
return Range{};
}
ConstantValue min = lhs.min().sub(rhs.max().value());
ConstantValue max = lhs.max()->sub(rhs.min());
if (max.lt(min)) {
VLOG(5) << "Subtract wrapped";
return Range{};
}
return Range{min, max, FindStepForBinaryOp(lhs, rhs),
lhs.IsLinear() && rhs.IsLinear()};
return RecordAndReturnRange(Range{min, max, FindStepForBinaryOp(lhs, rhs),
lhs.IsLinear() && rhs.IsLinear()},
instr, known_ranges);
}
default:
break;
Expand Down
Loading