Skip to content

Commit

Permalink
buffer init fix and gpu_hlo_runner test
Browse files Browse the repository at this point in the history
  • Loading branch information
pemeliya committed Oct 28, 2024
1 parent f002ae4 commit 8ec8571
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 8 deletions.
16 changes: 11 additions & 5 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4351,10 +4351,15 @@ xla_cc_test(
],
)

cuda_library(
gpu_kernel_library(
name = "stream_executor_util_kernel",
srcs = if_cuda_is_configured(["stream_executor_util_kernel.cu.cc"]),
deps = ["@local_config_cuda//cuda:cuda_headers"],
srcs = ["stream_executor_util_kernel.cu.cc"],
tags = ["gpu"],
deps = if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
]) + if_rocm_is_configured([
"@local_config_rocm//rocm:rocm_headers",
]),
)

cc_library(
Expand All @@ -4366,7 +4371,6 @@ cc_library(
deps = [
":cublas_cudnn",
":launch_dimensions",
":stream_executor_util_kernel",
"//xla:autotuning_proto_cc",
"//xla:shape_util",
"//xla:statusor",
Expand Down Expand Up @@ -4394,7 +4398,9 @@ cc_library(
"@tsl//tsl/platform:ml_dtypes",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
],
] + if_gpu_is_configured([
":stream_executor_util_kernel",
]),
)

xla_cc_test(
Expand Down
6 changes: 5 additions & 1 deletion xla/service/gpu/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,11 @@ gpu_kernel_library(
"//xla:types",
"//xla/stream_executor/gpu:gpu_types_header",
"@tsl//tsl/lib/math:math_util",
],
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
]) + if_rocm_is_configured([
"@local_config_rocm//rocm:rocm_headers",
]),
)

xla_cc_test(
Expand Down
2 changes: 0 additions & 2 deletions xla/service/gpu/stream_executor_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,6 @@ static void InitializeTypedBuffer(se::Stream* stream,
// Nothing more to do
return;
}
#ifdef GOOGLE_CUDA
// Repeat the host_buffer_size elements at the start of `buf` to the end
CHECK_EQ(elements_to_fill, buffer.size() / sizeof(T) - host_buffer_size);
se::StreamExecutor* executor = stream->parent();
Expand All @@ -504,7 +503,6 @@ static void InitializeTypedBuffer(se::Stream* stream,
se::BlockDim(blocks_per_grid, 1, 1), *kernel,
buffer, host_buffer_bytes,
static_cast<int64_t>(buffer.size())));
#endif
}

void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type,
Expand Down
33 changes: 33 additions & 0 deletions xla/service/gpu/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,39 @@ xla_cc_test(
],
)

xla_test(
name = "gpu_hlo_runner_test",
srcs = ["gpu_hlo_runner_test.cc"],
backends = ["gpu"],
deps = [
":gpu_codegen_test",
"//xla:error_spec",
"//xla:test",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:buffer_assignment",
"//xla/service:executable",
"//xla/service:hlo_module_config",
"//xla/service:hlo_pass",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/service/gpu:gemm_rewriter",
"//xla/service/gpu:gpu_executable",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_memory_allocator",
"//xla/stream_executor:stream_executor_memory_allocator",
"//xla/tests:filecheck",
"//xla/tests:verified_hlo_module",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test_main",
]
)

xla_test(
name = "gemm_rewrite_test",
srcs = ["gemm_rewrite_test.cc"],
Expand Down
130 changes: 130 additions & 0 deletions xla/service/gpu/tests/gpu_hlo_runner_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/* Copyright 2022 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <fstream>
#include <sstream>
#include "xla/error_spec.h"
#include "xla/literal_comparison.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/gpu/tests/gpu_codegen_test.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/test_utils.h"

namespace xla {
namespace gpu {

template <class T>
std::vector<T*> MakePointerVector(std::vector<T>& input_vec) {
std::vector<T*> output_pointers;
output_pointers.reserve(input_vec.size());
for (auto& input : input_vec) {
output_pointers.push_back(&input);
}
return output_pointers;
}


class HloRunnerTest : public GpuCodegenTest {};

TEST_F(HloRunnerTest, RunSingle) {

std::ifstream ifs("input.hlo");
ASSERT_TRUE(ifs.good());

std::stringstream buffer;
buffer << ifs.rdbuf();

HloModuleConfig config = GetModuleConfigForTest();
#if 1
//config.set_num_partitions(8);

TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(buffer.str(),
config));

auto ref_module = module->Clone();
TF_ASSERT_OK_AND_ASSIGN(auto exec, test_runner_.CreateExecutable(std::move(module), true));

VLOG(0) << "Creating fake args..";
TF_ASSERT_OK_AND_ASSIGN(auto fake_arguments, xla::MakeFakeArguments(ref_module.get(),
true, /*pseudo-random*/
false /* use large range*/));
auto arg_ptrs = MakePointerVector<xla::Literal>(fake_arguments);

auto& ref_runner = HloTestBase::reference_runner_;
TF_ASSERT_OK_AND_ASSIGN(
auto ref_exec, ref_runner.CreateExecutable(std::move(ref_module), true));

// TF_ASSERT_OK_AND_ASSIGN(auto truth,
// ReadLiteralFromProto("/tf/xla/expected.pb"));
// TF_ASSERT_OK_AND_ASSIGN(auto truth,
// ref_runner.ExecuteWithExecutable(ref_exec.get(), arg_ptrs, nullptr));
// WriteLiteralToTempFile(truth, "expected");
//VLOG(0) << "Got expected literal from file.. running test";

TF_ASSERT_OK_AND_ASSIGN(
auto test_res, test_runner_.ExecuteWithExecutable(exec.get(), arg_ptrs));

VLOG(0) << "Running reference exec..";
TF_ASSERT_OK_AND_ASSIGN(
auto truth, ref_runner.ExecuteWithExecutable(ref_exec.get(), arg_ptrs));

ErrorSpec error_spec{1e-2, 1e-3};
//ErrorSpec error_spec(1e-5 /*abs*/, 1e-5 /*rel*/);
ASSERT_EQ(literal_comparison::Near(/*expected=*/truth,
/*actual=*/test_res,
/*error=*/error_spec,
/*detailed_message=*/true, {}), absl::OkStatus());

// EXPECT_TRUE(RunAndCompare(std::move(module),
// // absl::Span< xla::Literal * const>(arg_ptrs.data(), arg_ptrs.size()), error_spec));
#else
int NumReplicas = 8, NumParts = 1;
config.set_replica_count(NumReplicas);
config.set_num_partitions(NumParts);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(buffer.str(), config));
DeviceAssignment assn(/*replica_count=*/NumReplicas,
/*computation_count=*/NumParts);
for (int64_t i = 0, k = 0; i < NumReplicas; i++)
for (int64_t j = 0; j < NumParts; j++) {
assn(i, j) = k++;
}
auto fake_arguments = xla::MakeFakeArguments(
module.get(),
true, /*pseudo-random*/
false /* use large range*/).ValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(auto exec,
test_runner_.CreateExecutable(std::move(module), true));
for(int i = 0; i < 10; i++) {
VLOG(0) << "Running iteration #" << i;
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
HloTestBase::ExecuteReplicated(
[&](int64_t){ return exec.get(); },
[&fake_arguments](int64_t replica_id)
{ return fake_arguments.size(); },
[&fake_arguments](int64_t replica_id, int64_t idx)
{ return &fake_arguments[idx]; },
NumReplicas, false /*run hlo*/, &assn));
ASSERT_EQ(results.size(), NumReplicas);
}
#endif
}

} // namespace gpu
} // namespace xla

0 comments on commit 8ec8571

Please sign in to comment.