[xla:cpu] Add parallel loop runner
PiperOrigin-RevId: 706005631
ezhulenev authored and Google-ML-Automation committed Dec 13, 2024
1 parent 9ff0d7e commit 198dc23
Showing 4 changed files with 306 additions and 0 deletions.
29 changes: 29 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/BUILD
@@ -35,6 +35,35 @@ xla_cc_test(
    ],
)

cc_library(
    name = "parallel_loop_runner",
    srcs = ["parallel_loop_runner.cc"],
    hdrs = ["parallel_loop_runner.h"],
    deps = [
        "//xla/tsl/concurrency:async_value",
        "//xla/tsl/lib/math:math_util",
        "@com_google_absl//absl/base:core_headers",
        "@eigen_archive//:eigen3",
        "@tsl//tsl/platform:logging",
    ],
)

xla_cc_test(
    name = "parallel_loop_runner_test",
    srcs = ["parallel_loop_runner_test.cc"],
    deps = [
        ":parallel_loop_runner",
        "//xla/tsl/concurrency:async_value",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/synchronization",
        "@eigen_archive//:eigen3",
        "@tsl//tsl/platform:env",
        "@tsl//tsl/platform:test",
        "@tsl//tsl/platform:test_benchmark",
        "@tsl//tsl/platform:test_main",
    ],
)

cc_library(
    name = "xnn_interop",
    hdrs = ["xnn_interop.h"],
125 changes: 125 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
@@ -0,0 +1,125 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <utility>

#include "absl/base/optimization.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/chain.h"
#include "xla/tsl/lib/math/math_util.h"
#include "tsl/platform/logging.h"

#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"

namespace xla::cpu {

using Task = std::function<void(size_t task_index)>;

// Returns a non-reference-counted async value ref in an available (concrete)
// state.
//
// The returned async value is a per-process singleton stored in a storage
// with static duration, and can be safely compared using pointer equality.
static tsl::AsyncValueRef<tsl::Chain> OkDoneEventSingleton() {
  static tsl::AsyncValueOwningRef<tsl::Chain>* singleton = [] {
    auto* storage = new tsl::internal::AsyncValueStorage<tsl::Chain>();
    return new tsl::AsyncValueOwningRef<tsl::Chain>(
        tsl::MakeAvailableAsyncValueRef<tsl::Chain>(*storage));
  }();
  return singleton->AsRef();
}

// Schedules tasks in the [start_index, end_index) range into the Eigen thread
// pool using recursive work splitting. Executes the `start_index` task in the
// caller thread.
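//
// For example, scheduling the range [0, 8) enqueues recursive calls for
// [4, 8), [2, 4) and [1, 2) into the thread pool and then runs task(0) in the
// caller thread; each enqueued sub-range splits itself the same way.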
static void ScheduleRange(tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
                          Eigen::ThreadPoolDevice* device, size_t start_index,
                          size_t end_index, Task task) {
  CHECK_LT(start_index, end_index) << "Invalid task index range";  // Crash OK
  while (end_index - start_index > 1) {
    uint64_t mid_index = (start_index + end_index) / 2;
    device->enqueueNoNotification([device, mid_index, end_index, task,
                                   count_down] {
      ScheduleRange(std::move(count_down), device, mid_index, end_index, task);
    });
    end_index = mid_index;
  }
  task(start_index);
  count_down.CountDown();
}

ParallelLoopRunner::ParallelLoopRunner(Eigen::ThreadPoolDevice* device)
    : done_event_(OkDoneEventSingleton()), device_(device) {}

tsl::AsyncValueRef<tsl::Chain> ParallelLoopRunner::TakeDoneEvent(
    ParallelLoopRunner&& runner) {
  return std::move(runner.done_event_);
}

void ParallelLoopRunner::Parallelize(size_t range, size_t tile, Task1D task) {
  DCHECK(done_event_) << "Parallel loop runner is in moved-from state";

  size_t num_tasks = tsl::MathUtil::CeilOfRatio(range, tile);
  DCHECK_GT(num_tasks, 0) << "Expected at least one task";

  // Fast path for the degenerate parallel loop with a single task.
  if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
    DCHECK_EQ(range, tile) << "Expected range to be equal to tile";

    if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) {
      // If the done event is already available, execute the task immediately
      // in the caller thread. In this case we don't need to overwrite the done
      // event, because the existing one will correctly represent the state of
      // the parallel loop runner (all scheduled loops are ready).
      task(0, range);

    } else {
      // If the done event is not available, we have to overwrite it with a new
      // one that will be set to a concrete state after the task is executed.
      auto done_event = tsl::MakeConstructedAsyncValueRef<tsl::Chain>();
      done_event_.AndThen([range, done_event, task = std::move(task)] {
        task(0, range);
        done_event.SetStateConcrete();
      });
      done_event_ = std::move(done_event);
    }

    return;
  }

  // Schedule `num_tasks` tasks into the underlying thread pool when the done
  // event becomes available.
  tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_tasks);
  auto done_event = count_down.AsRef();

  done_event_.AndThen([this, num_tasks, range, tile, task = std::move(task),
                       count_down = std::move(count_down)] {
    ScheduleRange(std::move(count_down), device_, 0, num_tasks,
                  [range, tile, task = std::move(task)](size_t task_index) {
                    size_t offset = task_index * tile;
                    size_t extent = std::min(range - offset, tile);
                    task(offset, extent);
                  });
  });
  done_event_ = std::move(done_event);
}

} // namespace xla::cpu
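For reference, the `Parallelize(range, tile, task)` contract documented in the header below is equivalent to the following single-threaded loop. This is an illustrative sketch only; the function name is made up and is not part of this commit.

#include <algorithm>
#include <cstddef>
#include <functional>

// Sequential reference for ParallelLoopRunner::Parallelize(range, tile, task):
// task index i covers the contiguous tile starting at offset i * tile.
void SequentialParallelizeReference(
    size_t range, size_t tile,
    const std::function<void(size_t offset, size_t extent)>& task) {
  for (size_t offset = 0; offset < range; offset += tile) {
    size_t extent = std::min(range - offset, tile);
    task(offset, extent);
  }
}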
74 changes: 74 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
@@ -0,0 +1,74 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_RUNTIME_XNNPACK_PARALLEL_LOOP_RUNNER_H_
#define XLA_BACKENDS_CPU_RUNTIME_XNNPACK_PARALLEL_LOOP_RUNNER_H_

#include <cstddef>
#include <functional>

#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/chain.h"

namespace Eigen {
struct ThreadPoolDevice;
} // namespace Eigen

namespace xla::cpu {

// Parallel loop runner uses the underlying Eigen ThreadPoolDevice to execute
// parallel loops and provides implicit synchronization: the next parallel
// loop starts execution only after all tasks from the previous loop have
// completed.
//
// Scheduled parallel loops execute asynchronously without blocking the caller
// thread. It is the user's responsibility to ensure that all values captured
// by the task are valid until the task is completed.
//
// Parallel loop runner is an implementation of the `pthreadpool` API adaptor
// for the XLA:CPU runtime.
//
// WARNING: ParallelLoopRunner is not thread-safe and must be externally
// synchronized by the user.
class ParallelLoopRunner {
 public:
  explicit ParallelLoopRunner(Eigen::ThreadPoolDevice* device);

  // Takes ownership of the runner and returns a done event. After the done
  // event is transferred to the caller, it is illegal to schedule more
  // parallel loops on the moved-from runner.
  static tsl::AsyncValueRef<tsl::Chain> TakeDoneEvent(
      ParallelLoopRunner&& runner);

  using Task1D = std::function<void(size_t offset, size_t extent)>;

  // This function implements a parallel version of the following loop:
  //
  //   for (size_t i = 0; i < range; i += tile)
  //     task(i, std::min(range - i, tile));
  void Parallelize(size_t range, size_t tile, Task1D task);

  tsl::AsyncValueRef<tsl::Chain> done_event() const { return done_event_; }
  Eigen::ThreadPoolDevice* device() const { return device_; }

 private:
  // Async value that signals completion of the last scheduled parallel loop.
  tsl::AsyncValueRef<tsl::Chain> done_event_;

  Eigen::ThreadPoolDevice* device_;
};

} // namespace xla::cpu

#endif // XLA_BACKENDS_CPU_RUNTIME_XNNPACK_PARALLEL_LOOP_RUNNER_H_
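To make the implicit synchronization contract concrete, here is a minimal usage sketch. It is illustrative only and not part of this commit: the buffer, tile size, and loop bodies are made up, and the thread pool setup is assumed to mirror the test below.

#include <cstddef>
#include <utility>
#include <vector>

#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h"
#include "xla/tsl/concurrency/async_value_ref.h"

// `device` wraps an Eigen thread pool, as in the test below.
void RunTwoDependentLoops(Eigen::ThreadPoolDevice* device) {
  xla::cpu::ParallelLoopRunner runner(device);

  std::vector<float> data(1024, 0.0f);

  // First loop writes `data`.
  runner.Parallelize(1024, 256, [&](size_t offset, size_t extent) {
    for (size_t i = offset; i < offset + extent; ++i) data[i] = 1.0f;
  });

  // Second loop reads and updates `data`. It starts only after the first loop
  // has completed, so no explicit wait is needed between the two calls.
  runner.Parallelize(1024, 256, [&](size_t offset, size_t extent) {
    for (size_t i = offset; i < offset + extent; ++i) data[i] += 1.0f;
  });

  // Block until everything scheduled on the runner has finished before `data`
  // goes out of scope.
  tsl::BlockUntilReady(
      xla::cpu::ParallelLoopRunner::TakeDoneEvent(std::move(runner)));
}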
78 changes: 78 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner_test.cc
@@ -0,0 +1,78 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h"

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "tsl/platform/env.h"
#include "tsl/platform/test.h"
#include "tsl/platform/test_benchmark.h"
#include "tsl/platform/threadpool.h"

#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"

namespace xla::cpu {
namespace {

TEST(ParallelLoopRunnerTest, BackToBack1DLoops) {
  tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
  Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
                                 threads.NumThreads());
  ParallelLoopRunner runner(&device);

  std::vector<int32_t> data(1024);
  auto inc_range = [&](size_t offset, size_t extent) {
    for (size_t i = offset; i < offset + extent; ++i) {
      data[i] += 1;
    }
  };

  runner.Parallelize(1024, 1, inc_range);
  runner.Parallelize(1024, 2, inc_range);
  runner.Parallelize(1024, 3, inc_range);
  runner.Parallelize(1024, 4, inc_range);
  runner.Parallelize(1024, 5, inc_range);

  tsl::BlockUntilReady(ParallelLoopRunner::TakeDoneEvent(std::move(runner)));
  ASSERT_TRUE(absl::c_all_of(data, [](int32_t value) { return value == 5; }));
}

//===----------------------------------------------------------------------===//
// Performance benchmarks.
//===----------------------------------------------------------------------===//

static void BM_SingleTask1DLoop(benchmark::State& state) {
  tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
  Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
                                 threads.NumThreads());
  ParallelLoopRunner runner(&device);

  for (auto _ : state) {
    runner.Parallelize(1, 1, [](size_t, size_t) {});
    tsl::BlockUntilReady(runner.done_event());
  }
}

BENCHMARK(BM_SingleTask1DLoop);

} // namespace
} // namespace xla::cpu
