From d85be10681263d74e1dd13ce0da6c06e06242f25 Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Thu, 12 Dec 2024 16:31:51 -0800 Subject: [PATCH] Switch multihost runner to public XLA:GPU target PiperOrigin-RevId: 705668066 --- xla/tools/multihost_hlo_runner/BUILD | 2 +- xla/tools/multihost_hlo_runner/create_client.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/xla/tools/multihost_hlo_runner/BUILD b/xla/tools/multihost_hlo_runner/BUILD index 144f4bdaa847a..66a854d84711e 100644 --- a/xla/tools/multihost_hlo_runner/BUILD +++ b/xla/tools/multihost_hlo_runner/BUILD @@ -111,10 +111,10 @@ cc_library( "//xla/pjrt/distributed:client", "//xla/pjrt/distributed:key_value_store_interface", "//xla/pjrt/distributed:service", - "//xla/pjrt/gpu:se_gpu_pjrt_client", "//xla/pjrt/plugin/xla_cpu:cpu_client_options", "//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", + "//xla/pjrt/plugin/xla_gpu:xla_gpu_pjrt_client", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/xla/tools/multihost_hlo_runner/create_client.cc b/xla/tools/multihost_hlo_runner/create_client.cc index a1c3bb027b5a7..822766ff392ab 100644 --- a/xla/tools/multihost_hlo_runner/create_client.cc +++ b/xla/tools/multihost_hlo_runner/create_client.cc @@ -27,12 +27,12 @@ limitations under the License. #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/distributed.h" #include "xla/pjrt/distributed/service.h" -#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" #include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" +#include "xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.h" #include "xla/status_macros.h" #include "xla/xla.pb.h" #include "tsl/platform/status.h" @@ -113,7 +113,7 @@ absl::StatusOr> CreateGpuClient( return absl::InvalidArgumentError( "Node id is expected to be in range [0, num_nodes)"); } - return GetStreamExecutorGpuClient(options); + return xla::GetXlaPjrtGpuClient(options); } absl::StatusOr> CreateMockGpuClient(int num_nodes) {