Skip to content

Commit

Permalink
Use absl::Mutex::Await() instead of absl::CondVar::Wait() in XLA.
Browse files Browse the repository at this point in the history
This CL replaces all uses of absl::CondVar::Wait() in XLA with absl::Mutex::Await().

PiperOrigin-RevId: 705759762
  • Loading branch information
majnemer authored and Google-ML-Automation committed Dec 13, 2024
1 parent 540c69c commit e650c3d
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 71 deletions.
10 changes: 5 additions & 5 deletions xla/backends/cpu/codegen/jit_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include <string_view>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -334,15 +335,14 @@ void JitCompiler::TaskDispatcher::dispatch(

absl::MutexLock lock(&mu_);
--num_dispatched_tasks_;
cv_.SignalAll();
});
}

void JitCompiler::TaskDispatcher::shutdown() {
absl::MutexLock lock(&mu_);
while (num_dispatched_tasks_ > 0) {
cv_.Wait(&mu_);
}
auto all_tasks_finished = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) {
return num_dispatched_tasks_ == 0;
};
absl::MutexLock lock(&mu_, absl::Condition(&all_tasks_finished));
}

JitCompiler::CompiledFunctionLibrary::CompiledFunctionLibrary(
Expand Down
1 change: 0 additions & 1 deletion xla/backends/cpu/codegen/jit_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ class JitCompiler {
TaskRunner task_runner_;

absl::Mutex mu_;
absl::CondVar cv_;
size_t num_dispatched_tasks_ ABSL_GUARDED_BY(mu_) = 0;
};

Expand Down
17 changes: 5 additions & 12 deletions xla/service/cpu/xfeed_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <cstdint>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/shape.h"
Expand All @@ -31,27 +32,19 @@ namespace runtime {
void XfeedQueueManager::EnqueueBuffersAtomically(
absl::Span<XfeedBuffer* const> buffers) {
absl::MutexLock l(&mu_);
bool was_empty = enqueued_buffers_.empty();
for (XfeedBuffer* b : buffers) {
VLOG(3) << "Enqueueing " << queue_name_ << " buffer (of " << buffers.size()
<< " buffers) with length: " << b->length();
enqueued_buffers_.push_back(b);
}
if (was_empty && !buffers.empty()) {
// This has the potential to suffer from the notified thread
// immediately trying and failing to acquire mu_, but seems
// preferable to the alternative of notifying outside the lock
// on every enqueue.
cv_.Signal();
}
}

XfeedBuffer* XfeedQueueManager::BlockingDequeueBuffer() {
absl::MutexLock l(&mu_);
VLOG(3) << "Waiting for an available buffer.";
while (enqueued_buffers_.empty()) {
cv_.Wait(&mu_);
}
auto available_buffer = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) {
return !enqueued_buffers_.empty();
};
absl::MutexLock l(&mu_, absl::Condition(&available_buffer));
VLOG(3) << "A buffer is available!";
CHECK(current_buffer_ == nullptr);
current_buffer_ = enqueued_buffers_.front();
Expand Down
4 changes: 0 additions & 4 deletions xla/service/cpu/xfeed_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,6 @@ class XfeedQueueManager {

absl::Mutex mu_;

// Condition variable that is signaled every time a buffer is
// enqueued to an empty queue.
absl::CondVar cv_;

// XfeedBuffer* queue contents are not owned, but buffer->Done must
// be called when the buffer is no longer needed by the runtime.
std::deque<XfeedBuffer*> enqueued_buffers_;
Expand Down
37 changes: 17 additions & 20 deletions xla/service/gpu/xfeed_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ class XfeedQueue {
void EnqueueDestination(BufferType buffers) {
absl::MutexLock l(&mu_);
enqueued_buffers_.push_back(std::move(buffers));
enqueue_cv_.Signal();

EnqueueHook();
}
Expand All @@ -57,10 +56,8 @@ class XfeedQueue {
bool became_empty;
BufferType current_buffer;
{
absl::MutexLock l(&mu_);
while (enqueued_buffers_.empty()) {
enqueue_cv_.Wait(&mu_);
}
absl::MutexLock l(&mu_,
absl::Condition(this, &XfeedQueue::IsBufferEnqueued));
current_buffer = std::move(enqueued_buffers_.front());
enqueued_buffers_.pop_front();
DequeueHook();
Expand Down Expand Up @@ -94,8 +91,10 @@ class XfeedQueue {
std::deque<BufferType> enqueued_buffers_ ABSL_GUARDED_BY(mu_);

private:
// Condition variable that is signaled every time a buffer is enqueued.
absl::CondVar enqueue_cv_;
// Returns true if there is a buffer in the queue.
bool IsBufferEnqueued() const ABSL_SHARED_LOCKS_REQUIRED(mu_) {
return !enqueued_buffers_.empty();
}

// List of callbacks which will be called when 'enqueued_buffers_' becomes
// empty.
Expand All @@ -122,14 +121,9 @@ class BlockingXfeedQueue : public XfeedQueue<BufferType> {
: max_pending_xfeeds_(max_pending_xfeeds) {}

void BlockUntilEnqueueSlotAvailable() {
absl::MutexLock l{&this->mu_};
while (pending_buffers_ + this->enqueued_buffers_.size() >=
max_pending_xfeeds_) {
VLOG(2) << "Capacity "
<< (pending_buffers_ + this->enqueued_buffers_.size())
<< " >= max capacity " << max_pending_xfeeds_;
dequeue_cv_.Wait(&this->mu_);
}
absl::MutexLock l{
&this->mu_,
absl::Condition(this, &BlockingXfeedQueue::IsEnqueueSlotAvailable)};

pending_buffers_++;
}
Expand All @@ -139,15 +133,18 @@ class BlockingXfeedQueue : public XfeedQueue<BufferType> {
pending_buffers_--;
}

void DequeueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {
dequeue_cv_.Signal();
}
void DequeueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {}

private:
const int max_pending_xfeeds_;

// Condition variable that is signaled every time a buffer is dequeued.
absl::CondVar dequeue_cv_;
bool IsEnqueueSlotAvailable() const ABSL_SHARED_LOCKS_REQUIRED(this->mu_) {
VLOG(2) << "Capacity "
<< (pending_buffers_ + this->enqueued_buffers_.size())
<< " >= max capacity " << max_pending_xfeeds_;
return pending_buffers_ + this->enqueued_buffers_.size() <
max_pending_xfeeds_;
}

// Keeps track of the number of buffers reserved but not added to
// enqueued_buffers_.
Expand Down
6 changes: 3 additions & 3 deletions xla/tsl/platform/default/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,10 @@ cc_library(
"nobuilder",
],
deps = [
"@com_google_absl//absl/memory",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:mutex",
"@tsl//tsl/platform:notification",
"@tsl//tsl/platform:platform_port",
],
)
Expand Down
29 changes: 14 additions & 15 deletions xla/tsl/platform/default/unbounded_work_queue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,25 @@ limitations under the License.

#include "xla/tsl/platform/default/unbounded_work_queue.h"

#include "absl/memory/memory.h"
#include <utility>

#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "tsl/platform/env.h"
#include "tsl/platform/mutex.h"
#include "tsl/platform/numa.h"

namespace tsl {

UnboundedWorkQueue::UnboundedWorkQueue(Env* env, const string& thread_name,
UnboundedWorkQueue::UnboundedWorkQueue(Env* env, absl::string_view thread_name,
const ThreadOptions& thread_options)
: env_(env), thread_name_(thread_name), thread_options_(thread_options) {}

UnboundedWorkQueue::~UnboundedWorkQueue() {
{
mutex_lock l(work_queue_mu_);
absl::MutexLock l(&work_queue_mu_);
// Wake up all `PooledThreadFunc` threads and cause them to terminate before
// joining them when `threads_` is cleared.
cancelled_ = true;
work_queue_cv_.notify_all();
if (!work_queue_.empty()) {
LOG(ERROR) << "UnboundedWorkQueue named \"" << thread_name_ << "\" was "
<< "deleted with pending work in its queue. This may indicate "
Expand All @@ -41,7 +42,7 @@ UnboundedWorkQueue::~UnboundedWorkQueue() {
}

{
mutex_lock l(thread_pool_mu_);
absl::MutexLock l(&thread_pool_mu_);
// Clear the list of pooled threads, which will eventually terminate due to
// the previous notification.
//
Expand All @@ -55,9 +56,8 @@ UnboundedWorkQueue::~UnboundedWorkQueue() {
void UnboundedWorkQueue::Schedule(WorkFunction fn) {
// Enqueue a work item for the new thread's function, and wake up a
// cached thread to process it.
mutex_lock l(work_queue_mu_);
absl::MutexLock l(&work_queue_mu_);
work_queue_.push_back(std::move(fn));
work_queue_cv_.notify_one();
// NOTE: The queue may be non-empty, so we must account for queued work when
// considering how many threads are free.
if (work_queue_.size() > num_idle_threads_) {
Expand All @@ -67,7 +67,7 @@ void UnboundedWorkQueue::Schedule(WorkFunction fn) {
Thread* new_thread =
env_->StartThread({}, thread_name_, [this]() { PooledThreadFunc(); });

mutex_lock l(thread_pool_mu_);
absl::MutexLock l(&thread_pool_mu_);
thread_pool_.emplace_back(new_thread);
}
}
Expand All @@ -81,13 +81,12 @@ void UnboundedWorkQueue::PooledThreadFunc() {
while (true) {
WorkFunction fn;
{
mutex_lock l(work_queue_mu_);
absl::MutexLock l(&work_queue_mu_);
++num_idle_threads_;
while (!cancelled_ && work_queue_.empty()) {
// Wait for a new work function to be submitted, or the cache to be
// destroyed.
work_queue_cv_.wait(l);
}
// Wait for a new work function to be submitted, or the cache to be
// destroyed.
work_queue_mu_.Await(
absl::Condition(this, &UnboundedWorkQueue::HasWorkOrIsCancelled));
if (cancelled_) {
return;
}
Expand Down
29 changes: 18 additions & 11 deletions xla/tsl/platform/default/unbounded_work_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@ limitations under the License.
#ifndef XLA_TSL_PLATFORM_DEFAULT_UNBOUNDED_WORK_QUEUE_H_
#define XLA_TSL_PLATFORM_DEFAULT_UNBOUNDED_WORK_QUEUE_H_

#include <cstddef>
#include <deque>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "tsl/platform/env.h"
#include "tsl/platform/mutex.h"
#include "tsl/platform/notification.h"

namespace tsl {

Expand All @@ -36,7 +40,7 @@ namespace tsl {
// fragmentation that can result from excessive thread creation.
class UnboundedWorkQueue {
public:
UnboundedWorkQueue(Env* env, const string& thread_name,
UnboundedWorkQueue(Env* env, absl::string_view thread_name,
const ThreadOptions& thread_options = {});
~UnboundedWorkQueue();

Expand All @@ -50,17 +54,20 @@ class UnboundedWorkQueue {
private:
void PooledThreadFunc();

bool HasWorkOrIsCancelled() const ABSL_SHARED_LOCKS_REQUIRED(work_queue_mu_) {
return !work_queue_.empty() || cancelled_;
}

Env* const env_; // Not owned.
const string thread_name_;
const std::string thread_name_;
const ThreadOptions thread_options_;
mutex work_queue_mu_;
condition_variable work_queue_cv_ TF_GUARDED_BY(work_queue_mu_);
size_t num_idle_threads_ TF_GUARDED_BY(work_queue_mu_) = 0;
bool cancelled_ TF_GUARDED_BY(work_queue_mu_) = false;
std::deque<WorkFunction> work_queue_ TF_GUARDED_BY(work_queue_mu_);
mutex thread_pool_mu_;
absl::Mutex work_queue_mu_;
size_t num_idle_threads_ ABSL_GUARDED_BY(work_queue_mu_) = 0;
bool cancelled_ ABSL_GUARDED_BY(work_queue_mu_) = false;
std::deque<WorkFunction> work_queue_ ABSL_GUARDED_BY(work_queue_mu_);
absl::Mutex thread_pool_mu_;
std::vector<std::unique_ptr<Thread>> thread_pool_
TF_GUARDED_BY(thread_pool_mu_);
ABSL_GUARDED_BY(thread_pool_mu_);
};

} // namespace tsl
Expand Down

0 comments on commit e650c3d

Please sign in to comment.