From 6fe11166eb025d749c45f87e49416b3976fb5c0d Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Wed, 11 Dec 2024 08:55:07 -0800 Subject: [PATCH] [XLA] Return the number of overlapping chunks instead of chunks themselves for tracking outstanding prefetches/evictions PiperOrigin-RevId: 705124111 --- xla/service/heap_simulator/heap_simulator.cc | 30 +++++++++++++++++++ xla/service/heap_simulator/heap_simulator.h | 4 +++ .../heap_simulator/heap_simulator_test.cc | 6 ++++ .../memory_space_assignment/algorithm.cc | 16 ++++------ 4 files changed, 46 insertions(+), 10 deletions(-) diff --git a/xla/service/heap_simulator/heap_simulator.cc b/xla/service/heap_simulator/heap_simulator.cc index 9ceb861e0fce2..76357be5ac39d 100644 --- a/xla/service/heap_simulator/heap_simulator.cc +++ b/xla/service/heap_simulator/heap_simulator.cc @@ -944,6 +944,36 @@ bool BufferIntervalTree::Remove(int64_t start, int64_t end, return true; } +int BufferIntervalTree::NumChunksOverlappingInTime(int64_t start, + int64_t end) const { + int result = 0; + if (root_ == nullptr) { + return result; + } + std::vector visiting_stack; + visiting_stack.push_back(root_); + while (!visiting_stack.empty()) { + const BufferIntervalTreeNode* top = visiting_stack.back(); + visiting_stack.pop_back(); + if (start > top->subtree_end) { + continue; + } + if (top->left != nullptr) { + visiting_stack.push_back(top->left); + } + if (top->start <= end && top->end >= start) { + ++result; + } + if (end < top->start) { + continue; + } + if (top->right != nullptr) { + visiting_stack.push_back(top->right); + } + } + return result; +} + std::vector BufferIntervalTree::ChunksOverlappingInTime( int64_t start, int64_t end) const { std::vector result; diff --git a/xla/service/heap_simulator/heap_simulator.h b/xla/service/heap_simulator/heap_simulator.h index 7328f87722b60..d81b29b52ad45 100644 --- a/xla/service/heap_simulator/heap_simulator.h +++ b/xla/service/heap_simulator/heap_simulator.h @@ -363,6 +363,10 @@ class BufferIntervalTree { // Remove the interval from the tree. Returns true if the chunk is removed. bool Remove(int64_t start, int64_t end, const Chunk& chunk); + // Returns the number of allocated chunks that overlap with the given time + // interval. + int NumChunksOverlappingInTime(int64_t start, int64_t end) const; + // Returns vector of allocated chunks that overlap with the given time // interval. std::vector ChunksOverlappingInTime(int64_t start, int64_t end) const; diff --git a/xla/service/heap_simulator/heap_simulator_test.cc b/xla/service/heap_simulator/heap_simulator_test.cc index d27dbd14d81cc..612e7b060d886 100644 --- a/xla/service/heap_simulator/heap_simulator_test.cc +++ b/xla/service/heap_simulator/heap_simulator_test.cc @@ -1862,10 +1862,16 @@ TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsLeft) { BufferIntervalTree tree; tree.Add(20, 36, chunk); tree.Add(1, 45, chunk); + EXPECT_EQ(tree.NumChunksOverlappingInTime(10, 25), 2); + EXPECT_EQ(tree.NumChunksOverlappingInTime(5, 15), 1); EXPECT_TRUE(tree.Remove(1, 45, chunk)); + EXPECT_EQ(tree.NumChunksOverlappingInTime(10, 25), 1); + EXPECT_EQ(tree.NumChunksOverlappingInTime(5, 15), 0); EXPECT_EQ(tree.GetRoot()->subtree_end, 36); EXPECT_TRUE(tree.Remove(20, 36, chunk)); ASSERT_EQ(tree.GetRoot(), nullptr); + EXPECT_EQ(tree.NumChunksOverlappingInTime(10, 25), 0); + EXPECT_EQ(tree.NumChunksOverlappingInTime(5, 15), 0); } TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsRight) { diff --git a/xla/service/memory_space_assignment/algorithm.cc b/xla/service/memory_space_assignment/algorithm.cc index 1f64dcc3df66a..0130867714bf9 100644 --- a/xla/service/memory_space_assignment/algorithm.cc +++ b/xla/service/memory_space_assignment/algorithm.cc @@ -4669,19 +4669,15 @@ bool MsaAlgorithm::ViolatesMaximumOutstandingAsyncCopies( // Count the prefetches/evictions in the interval tree for the given interval. if (is_prefetch) { - int64_t num_prefetches = - prefetch_interval_tree_ - .ChunksOverlappingInTime(inclusive_start_time, end_time) - .size() + - num_additional_copies; + int64_t num_prefetches = prefetch_interval_tree_.NumChunksOverlappingInTime( + inclusive_start_time, end_time) + + num_additional_copies; return num_prefetches >= options_.max_outstanding_prefetches + extra_async_copy_limit; } else { - int64_t num_evictions = - eviction_interval_tree_ - .ChunksOverlappingInTime(inclusive_start_time, end_time) - .size() + - num_additional_copies; + int64_t num_evictions = eviction_interval_tree_.NumChunksOverlappingInTime( + inclusive_start_time, end_time) + + num_additional_copies; return num_evictions >= options_.max_outstanding_evictions + extra_async_copy_limit; }