diff --git a/xla/backends/cpu/codegen/jit_compiler.cc b/xla/backends/cpu/codegen/jit_compiler.cc index 03dcfad9033b8..2851caaeb7b6a 100644 --- a/xla/backends/cpu/codegen/jit_compiler.cc +++ b/xla/backends/cpu/codegen/jit_compiler.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -334,15 +335,14 @@ void JitCompiler::TaskDispatcher::dispatch( absl::MutexLock lock(&mu_); --num_dispatched_tasks_; - cv_.SignalAll(); }); } void JitCompiler::TaskDispatcher::shutdown() { - absl::MutexLock lock(&mu_); - while (num_dispatched_tasks_ > 0) { - cv_.Wait(&mu_); - } + auto all_tasks_finished = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) { + return num_dispatched_tasks_ == 0; + }; + absl::MutexLock lock(&mu_, absl::Condition(&all_tasks_finished)); } JitCompiler::CompiledFunctionLibrary::CompiledFunctionLibrary( diff --git a/xla/backends/cpu/codegen/jit_compiler.h b/xla/backends/cpu/codegen/jit_compiler.h index 771e65380780e..8d4aabac58cdb 100644 --- a/xla/backends/cpu/codegen/jit_compiler.h +++ b/xla/backends/cpu/codegen/jit_compiler.h @@ -157,7 +157,6 @@ class JitCompiler { TaskRunner task_runner_; absl::Mutex mu_; - absl::CondVar cv_; size_t num_dispatched_tasks_ ABSL_GUARDED_BY(mu_) = 0; }; diff --git a/xla/service/cpu/xfeed_manager.cc b/xla/service/cpu/xfeed_manager.cc index 36f2c9c7c308a..9f55980ae41ab 100644 --- a/xla/service/cpu/xfeed_manager.cc +++ b/xla/service/cpu/xfeed_manager.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/shape.h" @@ -31,27 +32,19 @@ namespace runtime { void XfeedQueueManager::EnqueueBuffersAtomically( absl::Span buffers) { absl::MutexLock l(&mu_); - bool was_empty = enqueued_buffers_.empty(); for (XfeedBuffer* b : buffers) { VLOG(3) << "Enqueueing " << queue_name_ << " buffer (of " << buffers.size() << " buffers) with length: " << b->length(); enqueued_buffers_.push_back(b); } - if (was_empty && !buffers.empty()) { - // This has the potential to suffer from the notified thread - // immediately trying and failing to acquire mu_, but seems - // preferable to the alternative of notifying outside the lock - // on every enqueue. - cv_.Signal(); - } } XfeedBuffer* XfeedQueueManager::BlockingDequeueBuffer() { - absl::MutexLock l(&mu_); VLOG(3) << "Waiting for an available buffer."; - while (enqueued_buffers_.empty()) { - cv_.Wait(&mu_); - } + auto available_buffer = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) { + return !enqueued_buffers_.empty(); + }; + absl::MutexLock l(&mu_, absl::Condition(&available_buffer)); VLOG(3) << "A buffer is available!"; CHECK(current_buffer_ == nullptr); current_buffer_ = enqueued_buffers_.front(); diff --git a/xla/service/cpu/xfeed_manager.h b/xla/service/cpu/xfeed_manager.h index 19664ba9f4cba..3dee7629fdc22 100644 --- a/xla/service/cpu/xfeed_manager.h +++ b/xla/service/cpu/xfeed_manager.h @@ -86,10 +86,6 @@ class XfeedQueueManager { absl::Mutex mu_; - // Condition variable that is signaled every time a buffer is - // enqueued to an empty queue. - absl::CondVar cv_; - // XfeedBuffer* queue contents are not owned, but buffer->Done must // be called when the buffer is no longer needed by the runtime. std::deque enqueued_buffers_; diff --git a/xla/service/gpu/xfeed_queue.h b/xla/service/gpu/xfeed_queue.h index 18f63a934a17c..737bc921a2e3e 100644 --- a/xla/service/gpu/xfeed_queue.h +++ b/xla/service/gpu/xfeed_queue.h @@ -42,7 +42,6 @@ class XfeedQueue { void EnqueueDestination(BufferType buffers) { absl::MutexLock l(&mu_); enqueued_buffers_.push_back(std::move(buffers)); - enqueue_cv_.Signal(); EnqueueHook(); } @@ -57,10 +56,8 @@ class XfeedQueue { bool became_empty; BufferType current_buffer; { - absl::MutexLock l(&mu_); - while (enqueued_buffers_.empty()) { - enqueue_cv_.Wait(&mu_); - } + absl::MutexLock l(&mu_, + absl::Condition(this, &XfeedQueue::IsBufferEnqueued)); current_buffer = std::move(enqueued_buffers_.front()); enqueued_buffers_.pop_front(); DequeueHook(); @@ -94,8 +91,10 @@ class XfeedQueue { std::deque enqueued_buffers_ ABSL_GUARDED_BY(mu_); private: - // Condition variable that is signaled every time a buffer is enqueued. - absl::CondVar enqueue_cv_; + // Returns true if there is a buffer in the queue. + bool IsBufferEnqueued() const ABSL_SHARED_LOCKS_REQUIRED(mu_) { + return !enqueued_buffers_.empty(); + } // List of callbacks which will be called when 'enqueued_buffers_' becomes // empty. @@ -122,14 +121,9 @@ class BlockingXfeedQueue : public XfeedQueue { : max_pending_xfeeds_(max_pending_xfeeds) {} void BlockUntilEnqueueSlotAvailable() { - absl::MutexLock l{&this->mu_}; - while (pending_buffers_ + this->enqueued_buffers_.size() >= - max_pending_xfeeds_) { - VLOG(2) << "Capacity " - << (pending_buffers_ + this->enqueued_buffers_.size()) - << " >= max capacity " << max_pending_xfeeds_; - dequeue_cv_.Wait(&this->mu_); - } + absl::MutexLock l{ + &this->mu_, + absl::Condition(this, &BlockingXfeedQueue::IsEnqueueSlotAvailable)}; pending_buffers_++; } @@ -139,15 +133,18 @@ class BlockingXfeedQueue : public XfeedQueue { pending_buffers_--; } - void DequeueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override { - dequeue_cv_.Signal(); - } + void DequeueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {} private: const int max_pending_xfeeds_; - // Condition variable that is signaled every time a buffer is dequeued. - absl::CondVar dequeue_cv_; + bool IsEnqueueSlotAvailable() const ABSL_SHARED_LOCKS_REQUIRED(this->mu_) { + VLOG(2) << "Capacity " + << (pending_buffers_ + this->enqueued_buffers_.size()) + << " >= max capacity " << max_pending_xfeeds_; + return pending_buffers_ + this->enqueued_buffers_.size() < + max_pending_xfeeds_; + } // Keeps track of the number of buffers reserved but not added to // enqueued_buffers_. diff --git a/xla/tsl/platform/default/BUILD b/xla/tsl/platform/default/BUILD index a784764f82308..a027326f1881f 100644 --- a/xla/tsl/platform/default/BUILD +++ b/xla/tsl/platform/default/BUILD @@ -482,10 +482,10 @@ cc_library( "nobuilder", ], deps = [ - "@com_google_absl//absl/memory", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", "@tsl//tsl/platform:env", - "@tsl//tsl/platform:mutex", - "@tsl//tsl/platform:notification", "@tsl//tsl/platform:platform_port", ], ) diff --git a/xla/tsl/platform/default/unbounded_work_queue.cc b/xla/tsl/platform/default/unbounded_work_queue.cc index 818d54435439d..f8a9b055ff819 100644 --- a/xla/tsl/platform/default/unbounded_work_queue.cc +++ b/xla/tsl/platform/default/unbounded_work_queue.cc @@ -15,24 +15,25 @@ limitations under the License. #include "xla/tsl/platform/default/unbounded_work_queue.h" -#include "absl/memory/memory.h" +#include + +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "tsl/platform/env.h" -#include "tsl/platform/mutex.h" #include "tsl/platform/numa.h" namespace tsl { -UnboundedWorkQueue::UnboundedWorkQueue(Env* env, const string& thread_name, +UnboundedWorkQueue::UnboundedWorkQueue(Env* env, absl::string_view thread_name, const ThreadOptions& thread_options) : env_(env), thread_name_(thread_name), thread_options_(thread_options) {} UnboundedWorkQueue::~UnboundedWorkQueue() { { - mutex_lock l(work_queue_mu_); + absl::MutexLock l(&work_queue_mu_); // Wake up all `PooledThreadFunc` threads and cause them to terminate before // joining them when `threads_` is cleared. cancelled_ = true; - work_queue_cv_.notify_all(); if (!work_queue_.empty()) { LOG(ERROR) << "UnboundedWorkQueue named \"" << thread_name_ << "\" was " << "deleted with pending work in its queue. This may indicate " @@ -41,7 +42,7 @@ UnboundedWorkQueue::~UnboundedWorkQueue() { } { - mutex_lock l(thread_pool_mu_); + absl::MutexLock l(&thread_pool_mu_); // Clear the list of pooled threads, which will eventually terminate due to // the previous notification. // @@ -55,9 +56,8 @@ UnboundedWorkQueue::~UnboundedWorkQueue() { void UnboundedWorkQueue::Schedule(WorkFunction fn) { // Enqueue a work item for the new thread's function, and wake up a // cached thread to process it. - mutex_lock l(work_queue_mu_); + absl::MutexLock l(&work_queue_mu_); work_queue_.push_back(std::move(fn)); - work_queue_cv_.notify_one(); // NOTE: The queue may be non-empty, so we must account for queued work when // considering how many threads are free. if (work_queue_.size() > num_idle_threads_) { @@ -67,7 +67,7 @@ void UnboundedWorkQueue::Schedule(WorkFunction fn) { Thread* new_thread = env_->StartThread({}, thread_name_, [this]() { PooledThreadFunc(); }); - mutex_lock l(thread_pool_mu_); + absl::MutexLock l(&thread_pool_mu_); thread_pool_.emplace_back(new_thread); } } @@ -81,13 +81,12 @@ void UnboundedWorkQueue::PooledThreadFunc() { while (true) { WorkFunction fn; { - mutex_lock l(work_queue_mu_); + absl::MutexLock l(&work_queue_mu_); ++num_idle_threads_; - while (!cancelled_ && work_queue_.empty()) { - // Wait for a new work function to be submitted, or the cache to be - // destroyed. - work_queue_cv_.wait(l); - } + // Wait for a new work function to be submitted, or the cache to be + // destroyed. + work_queue_mu_.Await( + absl::Condition(this, &UnboundedWorkQueue::HasWorkOrIsCancelled)); if (cancelled_) { return; } diff --git a/xla/tsl/platform/default/unbounded_work_queue.h b/xla/tsl/platform/default/unbounded_work_queue.h index 401b2b596d350..5a61a4a5373b2 100644 --- a/xla/tsl/platform/default/unbounded_work_queue.h +++ b/xla/tsl/platform/default/unbounded_work_queue.h @@ -15,13 +15,17 @@ limitations under the License. #ifndef XLA_TSL_PLATFORM_DEFAULT_UNBOUNDED_WORK_QUEUE_H_ #define XLA_TSL_PLATFORM_DEFAULT_UNBOUNDED_WORK_QUEUE_H_ +#include #include +#include #include +#include #include +#include "absl/base/thread_annotations.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "tsl/platform/env.h" -#include "tsl/platform/mutex.h" -#include "tsl/platform/notification.h" namespace tsl { @@ -36,7 +40,7 @@ namespace tsl { // fragmentation that can result from excessive thread creation. class UnboundedWorkQueue { public: - UnboundedWorkQueue(Env* env, const string& thread_name, + UnboundedWorkQueue(Env* env, absl::string_view thread_name, const ThreadOptions& thread_options = {}); ~UnboundedWorkQueue(); @@ -50,17 +54,20 @@ class UnboundedWorkQueue { private: void PooledThreadFunc(); + bool HasWorkOrIsCancelled() const ABSL_SHARED_LOCKS_REQUIRED(work_queue_mu_) { + return !work_queue_.empty() || cancelled_; + } + Env* const env_; // Not owned. - const string thread_name_; + const std::string thread_name_; const ThreadOptions thread_options_; - mutex work_queue_mu_; - condition_variable work_queue_cv_ TF_GUARDED_BY(work_queue_mu_); - size_t num_idle_threads_ TF_GUARDED_BY(work_queue_mu_) = 0; - bool cancelled_ TF_GUARDED_BY(work_queue_mu_) = false; - std::deque work_queue_ TF_GUARDED_BY(work_queue_mu_); - mutex thread_pool_mu_; + absl::Mutex work_queue_mu_; + size_t num_idle_threads_ ABSL_GUARDED_BY(work_queue_mu_) = 0; + bool cancelled_ ABSL_GUARDED_BY(work_queue_mu_) = false; + std::deque work_queue_ ABSL_GUARDED_BY(work_queue_mu_); + absl::Mutex thread_pool_mu_; std::vector> thread_pool_ - TF_GUARDED_BY(thread_pool_mu_); + ABSL_GUARDED_BY(thread_pool_mu_); }; } // namespace tsl