Skip to content

Commit

Permalink
[XLA] Return the number of overlapping chunks instead of chunks thems…
Browse files Browse the repository at this point in the history
…elves for tracking outstanding prefetches/evictions

PiperOrigin-RevId: 705952176
  • Loading branch information
berkinilbeyi authored and Google-ML-Automation committed Dec 13, 2024
1 parent b2949d5 commit e1e71ef
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 10 deletions.
30 changes: 30 additions & 0 deletions xla/service/heap_simulator/heap_simulator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,36 @@ bool BufferIntervalTree::Remove(int64_t start, int64_t end,
return true;
}

int BufferIntervalTree::NumChunksOverlappingInTime(int64_t start,
int64_t end) const {
int result = 0;
if (root_ == nullptr) {
return result;
}
std::vector<const BufferIntervalTreeNode*> visiting_stack;
visiting_stack.push_back(root_);
while (!visiting_stack.empty()) {
const BufferIntervalTreeNode* top = visiting_stack.back();
visiting_stack.pop_back();
if (start > top->subtree_end) {
continue;
}
if (top->left != nullptr) {
visiting_stack.push_back(top->left);
}
if (top->start <= end && top->end >= start) {
++result;
}
if (end < top->start) {
continue;
}
if (top->right != nullptr) {
visiting_stack.push_back(top->right);
}
}
return result;
}

std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
int64_t start, int64_t end) const {
std::vector<Chunk> result;
Expand Down
4 changes: 4 additions & 0 deletions xla/service/heap_simulator/heap_simulator.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,10 @@ class BufferIntervalTree {
// Remove the interval from the tree. Returns true if the chunk is removed.
bool Remove(int64_t start, int64_t end, const Chunk& chunk);

// Returns the number of allocated chunks that overlap with the given time
// interval.
int NumChunksOverlappingInTime(int64_t start, int64_t end) const;

// Returns vector of allocated chunks that overlap with the given time
// interval.
std::vector<Chunk> ChunksOverlappingInTime(int64_t start, int64_t end) const;
Expand Down
6 changes: 6 additions & 0 deletions xla/service/heap_simulator/heap_simulator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1862,10 +1862,16 @@ TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsLeft) {
BufferIntervalTree tree;
tree.Add(20, 36, chunk);
tree.Add(1, 45, chunk);
EXPECT_EQ(tree.NumChunksOverlappingInTime(10, 25), 2);
EXPECT_EQ(tree.NumChunksOverlappingInTime(5, 15), 1);
EXPECT_TRUE(tree.Remove(1, 45, chunk));
EXPECT_EQ(tree.NumChunksOverlappingInTime(10, 25), 1);
EXPECT_EQ(tree.NumChunksOverlappingInTime(5, 15), 0);
EXPECT_EQ(tree.GetRoot()->subtree_end, 36);
EXPECT_TRUE(tree.Remove(20, 36, chunk));
ASSERT_EQ(tree.GetRoot(), nullptr);
EXPECT_EQ(tree.NumChunksOverlappingInTime(10, 25), 0);
EXPECT_EQ(tree.NumChunksOverlappingInTime(5, 15), 0);
}

TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsRight) {
Expand Down
16 changes: 6 additions & 10 deletions xla/service/memory_space_assignment/algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4669,19 +4669,15 @@ bool MsaAlgorithm::ViolatesMaximumOutstandingAsyncCopies(

// Count the prefetches/evictions in the interval tree for the given interval.
if (is_prefetch) {
int64_t num_prefetches =
prefetch_interval_tree_
.ChunksOverlappingInTime(inclusive_start_time, end_time)
.size() +
num_additional_copies;
int64_t num_prefetches = prefetch_interval_tree_.NumChunksOverlappingInTime(
inclusive_start_time, end_time) +
num_additional_copies;
return num_prefetches >=
options_.max_outstanding_prefetches + extra_async_copy_limit;
} else {
int64_t num_evictions =
eviction_interval_tree_
.ChunksOverlappingInTime(inclusive_start_time, end_time)
.size() +
num_additional_copies;
int64_t num_evictions = eviction_interval_tree_.NumChunksOverlappingInTime(
inclusive_start_time, end_time) +
num_additional_copies;
return num_evictions >=
options_.max_outstanding_evictions + extra_async_copy_limit;
}
Expand Down

0 comments on commit e1e71ef

Please sign in to comment.