Skip to content

Commit

Permalink
Switch multihost runner to public XLA:GPU target
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705668066
  • Loading branch information
changm authored and Google-ML-Automation committed Dec 13, 2024
1 parent a8a0a7e commit d85be10
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion xla/tools/multihost_hlo_runner/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ cc_library(
"//xla/pjrt/distributed:client",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/pjrt/distributed:service",
"//xla/pjrt/gpu:se_gpu_pjrt_client",
"//xla/pjrt/plugin/xla_cpu:cpu_client_options",
"//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
"//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
"//xla/pjrt/plugin/xla_gpu:xla_gpu_pjrt_client",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
Expand Down
4 changes: 2 additions & 2 deletions xla/tools/multihost_hlo_runner/create_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ limitations under the License.
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/distributed/service.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h"
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
#include "xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.h"
#include "xla/status_macros.h"
#include "xla/xla.pb.h"
#include "tsl/platform/status.h"
Expand Down Expand Up @@ -113,7 +113,7 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> CreateGpuClient(
return absl::InvalidArgumentError(
"Node id is expected to be in range [0, num_nodes)");
}
return GetStreamExecutorGpuClient(options);
return xla::GetXlaPjrtGpuClient(options);
}

absl::StatusOr<std::unique_ptr<PjRtClient>> CreateMockGpuClient(int num_nodes) {
Expand Down

0 comments on commit d85be10

Please sign in to comment.