From a7c38b739286fa2cb0a13d1f31d85f5a195bc534 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 3 Dec 2024 10:39:41 -0800 Subject: [PATCH] [xla:collectives] NFC: Remove NcclCliqueKey alias PiperOrigin-RevId: 702392147 --- xla/backends/gpu/collectives/BUILD | 3 +- xla/pjrt/gpu/BUILD | 2 +- xla/pjrt/gpu/nccl_id_store.cc | 4 +- xla/pjrt/gpu/nccl_id_store.h | 6 +-- xla/service/gpu/BUILD | 6 +-- xla/service/gpu/fusions/BUILD | 2 +- xla/service/gpu/fusions/custom.cc | 2 +- xla/service/gpu/gpu_executable.cc | 8 ++-- xla/service/gpu/ir_emitter_unnested.cc | 2 +- xla/service/gpu/runtime/BUILD | 48 +++++-------------- xla/service/gpu/runtime/command_buffer_cmd.cc | 10 ++-- xla/service/gpu/runtime/command_buffer_cmd.h | 2 +- .../gpu/runtime/nccl_all_to_all_thunk.cc | 2 +- .../gpu/runtime/nccl_all_to_all_thunk.h | 2 +- xla/service/gpu/runtime/nccl_api.cc | 2 +- xla/service/gpu/runtime/nccl_api_stub.cc | 1 - xla/service/gpu/runtime/nccl_clique.cc | 18 +++---- xla/service/gpu/runtime/nccl_clique.h | 18 +++---- xla/service/gpu/runtime/nccl_clique_key.h | 32 ------------- .../gpu/runtime/nccl_collective_thunk.cc | 34 ++++++------- .../gpu/runtime/nccl_collective_thunk.h | 4 +- .../gpu/runtime/nccl_p2p_thunk_common.cc | 2 +- .../gpu/runtime/nccl_p2p_thunk_common.h | 2 +- xla/service/gpu/runtime/nccl_recv_thunk.h | 2 +- xla/service/gpu/runtime/nccl_send_thunk.h | 2 +- xla/service/gpu/runtime/thunk.cc | 8 ++-- xla/service/gpu/runtime/thunk.h | 10 ++-- 27 files changed, 88 insertions(+), 146 deletions(-) delete mode 100644 xla/service/gpu/runtime/nccl_clique_key.h diff --git a/xla/backends/gpu/collectives/BUILD b/xla/backends/gpu/collectives/BUILD index 98a3261148437..868920737faef 100644 --- a/xla/backends/gpu/collectives/BUILD +++ b/xla/backends/gpu/collectives/BUILD @@ -1,12 +1,13 @@ load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("//xla:xla.bzl", "xla_cc_test") load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") +load("//xla/tsl:tsl.bzl", "internal_visibility") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") load("//xla/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) diff --git a/xla/pjrt/gpu/BUILD b/xla/pjrt/gpu/BUILD index 126a2879ca7e5..d9d0e5832a52c 100644 --- a/xla/pjrt/gpu/BUILD +++ b/xla/pjrt/gpu/BUILD @@ -219,12 +219,12 @@ cc_library( deps = [ "//xla:status_macros", "//xla:util", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/core/collectives:clique_id", "//xla/core/collectives:clique_key", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:global_device_id", "//xla/service/gpu/runtime:nccl_api", - "//xla/service/gpu/runtime:nccl_clique_key", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", diff --git a/xla/pjrt/gpu/nccl_id_store.cc b/xla/pjrt/gpu/nccl_id_store.cc index 3c0c586b3567c..54f1ff0c2f3ec 100644 --- a/xla/pjrt/gpu/nccl_id_store.cc +++ b/xla/pjrt/gpu/nccl_id_store.cc @@ -21,10 +21,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/core/collectives/clique_id.h" #include "xla/core/collectives/clique_key.h" #include "xla/service/gpu/runtime/nccl_api.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/status_macros.h" #include "xla/util.h" #include "tsl/platform/casts.h" @@ -34,7 +34,7 @@ limitations under the License. namespace xla { absl::StatusOr NcclIdStore::GetNcclUniqueId(const CliqueKey& key) { - auto* gpu_key = tsl::down_cast(&key); + auto* gpu_key = tsl::down_cast(&key); if (gpu_key == nullptr) { return InvalidArgument("Expected GPU clique key"); } diff --git a/xla/pjrt/gpu/nccl_id_store.h b/xla/pjrt/gpu/nccl_id_store.h index d235a58c242a1..fe8b060cb946a 100644 --- a/xla/pjrt/gpu/nccl_id_store.h +++ b/xla/pjrt/gpu/nccl_id_store.h @@ -23,15 +23,15 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/core/collectives/clique_id.h" #include "xla/core/collectives/clique_key.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/global_device_id.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" namespace xla { -// A table mapping NcclCliqueKeys to NcclCliqueIds. In a distributed setup the +// A table mapping GpuCliqueKeys to CliqueIds. In a distributed setup the // table of NCCL IDs is kept on the master node (node 0). The node of the first // participating device will create the unique id. class NcclIdStore { @@ -51,7 +51,7 @@ class NcclIdStore { const std::shared_ptr kv_store_; absl::Mutex mu_; - absl::flat_hash_map cache_ ABSL_GUARDED_BY(mu_); + absl::flat_hash_map cache_ ABSL_GUARDED_BY(mu_); }; } // namespace xla diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index fe31bf69f6c57..b814f4031fca1 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -97,10 +97,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//xla:executable_run_options", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/core/collectives:clique_id", "//xla/core/collectives:clique_key", "//xla/service:global_device_id", - "//xla/service/gpu/runtime:nccl_clique_key", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], @@ -339,6 +339,7 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/ffi:attribute_map", "//xla/ffi:ffi_api", "//xla/ffi/api:c_api", @@ -375,7 +376,6 @@ cc_library( "//xla/service/gpu/runtime:nccl_all_to_all_thunk", "//xla/service/gpu/runtime:nccl_api", "//xla/service/gpu/runtime:nccl_clique", - "//xla/service/gpu/runtime:nccl_clique_key", "//xla/service/gpu/runtime:nccl_collective_broadcast_thunk", "//xla/service/gpu/runtime:nccl_collective_permute_thunk", "//xla/service/gpu/runtime:nccl_collective_thunk", @@ -575,6 +575,7 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla:util", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:executable", @@ -587,7 +588,6 @@ cc_library( "//xla/service/gpu/runtime:annotation", "//xla/service/gpu/runtime:for_all_thunks", "//xla/service/gpu/runtime:nccl_clique", - "//xla/service/gpu/runtime:nccl_clique_key", "//xla/service/gpu/runtime:sequential_thunk", "//xla/service/gpu/runtime:thunk", "//xla/stream_executor:device_description", diff --git a/xla/service/gpu/fusions/BUILD b/xla/service/gpu/fusions/BUILD index 3b2da268da838..a0f3e5fb46faa 100644 --- a/xla/service/gpu/fusions/BUILD +++ b/xla/service/gpu/fusions/BUILD @@ -69,6 +69,7 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/ffi:attribute_map", "//xla/ffi:ffi_api", "//xla/hlo/ir:hlo", @@ -95,7 +96,6 @@ cc_library( "//xla/service/gpu/runtime:kernel_thunk", "//xla/service/gpu/runtime:nccl_all_reduce_thunk", "//xla/service/gpu/runtime:nccl_api", - "//xla/service/gpu/runtime:nccl_clique_key", "//xla/service/gpu/runtime:nccl_collective_thunk", "//xla/service/gpu/runtime:thunk", "@com_google_absl//absl/algorithm:container", diff --git a/xla/service/gpu/fusions/custom.cc b/xla/service/gpu/fusions/custom.cc index ba827d77d444a..1eeb4b9b07612 100644 --- a/xla/service/gpu/fusions/custom.cc +++ b/xla/service/gpu/fusions/custom.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/ffi/attribute_map.h" #include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -64,7 +65,6 @@ limitations under the License. #include "xla/service/gpu/runtime/kernel_thunk.h" #include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" #include "xla/service/gpu/runtime/nccl_api.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/stream_executor_util.h" diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc index 3064a9f890242..1fa2a294acc1f 100644 --- a/xla/service/gpu/gpu_executable.cc +++ b/xla/service/gpu/gpu_executable.cc @@ -37,6 +37,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -51,7 +52,6 @@ limitations under the License. #include "xla/service/gpu/runtime/annotation.h" #include "xla/service/gpu/runtime/for_all_thunks.h" #include "xla/service/gpu/runtime/nccl_clique.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/stream_executor_util.h" @@ -187,7 +187,7 @@ namespace { // Shared resources required for thunk initialization and execution. class ResourceRequests : public Thunk::ResourceRequests { public: - absl::Status AddClique(const NcclCliqueKey& clique_key, + absl::Status AddClique(const GpuCliqueKey& clique_key, int32_t num_local_participants) final { VLOG(5) << "Add collective clique request: " << clique_key.ToString() << "; num_local_participants: " << num_local_participants; @@ -286,7 +286,7 @@ class ResourceRequests : public Thunk::ResourceRequests { private: struct CliqueRequest { - NcclCliqueKey key; + GpuCliqueKey key; int64_t num_local_participants; int64_t id; }; @@ -326,7 +326,7 @@ class ResourceRequests : public Thunk::ResourceRequests { return cliques; } - absl::flat_hash_map cliques_; + absl::flat_hash_map cliques_; }; absl::Status MaybeSyncAndProfile(const ServiceExecutableRunOptions* run_options, diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc index 23fd64c4eb087..0201dbb4e008d 100644 --- a/xla/service/gpu/ir_emitter_unnested.cc +++ b/xla/service/gpu/ir_emitter_unnested.cc @@ -68,6 +68,7 @@ limitations under the License. #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/attribute_map.h" #include "xla/ffi/ffi_api.h" @@ -126,7 +127,6 @@ limitations under the License. #include "xla/service/gpu/runtime/nccl_all_to_all_thunk.h" #include "xla/service/gpu/runtime/nccl_api.h" #include "xla/service/gpu/runtime/nccl_clique.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h" #include "xla/service/gpu/runtime/nccl_collective_permute_thunk.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" diff --git a/xla/service/gpu/runtime/BUILD b/xla/service/gpu/runtime/BUILD index 62fa7f6a373d2..20d40c735e560 100644 --- a/xla/service/gpu/runtime/BUILD +++ b/xla/service/gpu/runtime/BUILD @@ -61,7 +61,6 @@ cc_library( ":nccl_all_reduce_thunk", ":nccl_all_to_all_thunk", ":nccl_api", - ":nccl_clique_key", ":nccl_collective_broadcast_thunk", ":nccl_collective_thunk", ":thunk", @@ -72,6 +71,7 @@ cc_library( "//xla:status_macros", "//xla:types", "//xla:util", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:communicator", "//xla/ffi:call_frame", @@ -212,10 +212,10 @@ cc_library( hdrs = ["nccl_api.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":nccl_clique_key", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/backends/gpu/collectives:nccl_communicator", "//xla/core/collectives:clique_id", @@ -257,9 +257,9 @@ cc_library( hdrs = ["nccl_api.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":nccl_clique_key", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:clique_id", "//xla/core/collectives:communicator", @@ -282,12 +282,12 @@ cc_library( hdrs = ["nccl_clique.h"], deps = [ ":nccl_api", - ":nccl_clique_key", "//xla:debug_options_flags", "//xla:executable_run_options", "//xla:status_macros", "//xla:types", "//xla:util", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/core/collectives:clique_id", "//xla/core/collectives:clique_key", "//xla/core/collectives:communicator", @@ -317,32 +317,6 @@ cc_library( ], ) -cc_library( - name = "nccl_clique_key", - hdrs = ["nccl_clique_key.h"], - compatible_with = get_compatible_with_portable(), - deps = [ - "//xla/backends/gpu/collectives:gpu_clique_key", - "//xla/backends/gpu/collectives:gpu_collectives", - "//xla/core/collectives:clique_id", - "//xla/core/collectives:clique_key", - "//xla/core/collectives:rank_id", - "//xla/service:global_device_id", - "//xla/tsl/lib/gtl:int_type", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/crc:crc32c", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:casts", - "@tsl//tsl/platform:logging", - ], -) - #===-------------------------------------------------------------------------------------------===// # XLA Thunks Runtime #===-------------------------------------------------------------------------------------------===// @@ -868,11 +842,11 @@ cc_library( hdrs = ["nccl_all_to_all_thunk.h"], deps = [ ":nccl_api", - ":nccl_clique_key", ":nccl_collective_thunk", ":thunk", "//xla:shape_util", "//xla:status_macros", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:communicator", "//xla/hlo/ir:hlo", @@ -946,12 +920,12 @@ cc_library( deps = [ ":nccl_api", ":nccl_clique", - ":nccl_clique_key", ":thunk", "//xla:debug_options_flags", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", @@ -995,12 +969,12 @@ cc_library( srcs = ["nccl_p2p_thunk_common.cc"], hdrs = ["nccl_p2p_thunk_common.h"], deps = [ - ":nccl_clique_key", ":nccl_collective_thunk", "//xla:executable_run_options", "//xla:shape_util", "//xla:status_macros", "//xla:xla_data_proto_cc", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/core/collectives:communicator", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", @@ -1023,11 +997,11 @@ cc_library( hdrs = ["nccl_recv_thunk.h"], deps = [ ":nccl_api", - ":nccl_clique_key", ":nccl_collective_thunk", ":nccl_p2p_thunk_common", ":thunk", "//xla:status_macros", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/core/collectives:communicator", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", @@ -1049,11 +1023,11 @@ cc_library( hdrs = ["nccl_send_thunk.h"], deps = [ ":nccl_api", - ":nccl_clique_key", ":nccl_collective_thunk", ":nccl_p2p_thunk_common", ":thunk", "//xla:status_macros", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/core/collectives:communicator", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", @@ -1075,9 +1049,9 @@ cc_library( hdrs = ["nccl_group_thunk.h"], deps = [ ":nccl_api", - ":nccl_clique_key", ":nccl_collective_thunk", ":thunk", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/hlo/ir:hlo", "//xla/stream_executor:stream", "@com_google_absl//absl/status", @@ -1189,8 +1163,8 @@ cc_library( hdrs = ["thunk.h"], deps = [ ":nccl_clique", - ":nccl_clique_key", "//xla:executable_run_options", + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/ffi:execution_context", diff --git a/xla/service/gpu/runtime/command_buffer_cmd.cc b/xla/service/gpu/runtime/command_buffer_cmd.cc index 96e3abf70f12b..f08ba05d80c1a 100644 --- a/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -41,6 +41,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/core/collectives/communicator.h" #include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" @@ -60,7 +61,6 @@ limitations under the License. #include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" #include "xla/service/gpu/runtime/nccl_all_to_all_thunk.h" #include "xla/service/gpu/runtime/nccl_api.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/thunk.h" @@ -1614,10 +1614,10 @@ absl::Status CollectiveCmd::Prepare( const Thunk::PrepareParams& params, Thunk::ResourceRequests& resource_requests) { TF_ASSIGN_OR_RETURN( - NcclCliqueKey clique_key, - GetNcclCliqueKey(*params.collective_params, config().replica_groups, - config().group_mode, nccl_stream_id(), - GetAsyncStreamKind())); + GpuCliqueKey clique_key, + GetGpuCliqueKey(*params.collective_params, config().replica_groups, + config().group_mode, nccl_stream_id(), + GetAsyncStreamKind())); TF_ASSIGN_OR_RETURN( size_t num_local_participants, GetNumLocalParticipants(*params.collective_params, diff --git a/xla/service/gpu/runtime/command_buffer_cmd.h b/xla/service/gpu/runtime/command_buffer_cmd.h index a0f20bd69ac87..a283576ace6a5 100644 --- a/xla/service/gpu/runtime/command_buffer_cmd.h +++ b/xla/service/gpu/runtime/command_buffer_cmd.h @@ -37,6 +37,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/ffi/api/c_api.h" #include "xla/hlo/ir/hlo_computation.h" @@ -49,7 +50,6 @@ limitations under the License. #include "xla/service/gpu/runtime/custom_call_thunk.h" #include "xla/service/gpu/runtime/dynamic_slice_thunk.h" #include "xla/service/gpu/runtime/nccl_api.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/shape.h" diff --git a/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc b/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc index 0368ac69c22ff..cc42d145a67ac 100644 --- a/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc +++ b/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc @@ -27,13 +27,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/core/collectives/communicator.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/runtime/nccl_api.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/shape.h" diff --git a/xla/service/gpu/runtime/nccl_all_to_all_thunk.h b/xla/service/gpu/runtime/nccl_all_to_all_thunk.h index 43a346310d295..e02bf7c9370da 100644 --- a/xla/service/gpu/runtime/nccl_all_to_all_thunk.h +++ b/xla/service/gpu/runtime/nccl_all_to_all_thunk.h @@ -25,11 +25,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/core/collectives/communicator.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/runtime/nccl_api.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/stream_executor/stream.h" diff --git a/xla/service/gpu/runtime/nccl_api.cc b/xla/service/gpu/runtime/nccl_api.cc index de5e89b04727f..33e5f1cd6cd3e 100644 --- a/xla/service/gpu/runtime/nccl_api.cc +++ b/xla/service/gpu/runtime/nccl_api.cc @@ -31,13 +31,13 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/nccl_communicator.h" #include "xla/core/collectives/clique_id.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" #include "xla/primitive_util.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/gpu_stream.h" diff --git a/xla/service/gpu/runtime/nccl_api_stub.cc b/xla/service/gpu/runtime/nccl_api_stub.cc index 611a0320b1fcc..84043b42c1242 100644 --- a/xla/service/gpu/runtime/nccl_api_stub.cc +++ b/xla/service/gpu/runtime/nccl_api_stub.cc @@ -27,7 +27,6 @@ limitations under the License. #include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/runtime/nccl_api.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream.h" diff --git a/xla/service/gpu/runtime/nccl_clique.cc b/xla/service/gpu/runtime/nccl_clique.cc index 3e575cb21c872..de3af2bcd5e34 100644 --- a/xla/service/gpu/runtime/nccl_clique.cc +++ b/xla/service/gpu/runtime/nccl_clique.cc @@ -38,6 +38,7 @@ limitations under the License. #include "absl/time/clock.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/core/collectives/clique_id.h" #include "xla/core/collectives/clique_key.h" #include "xla/core/collectives/communicator.h" @@ -47,7 +48,6 @@ limitations under the License. #include "xla/service/global_device_id.h" #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/gpu/runtime/nccl_api.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/lockable.h" #include "xla/service/rendezvous.h" #include "xla/status_macros.h" @@ -112,7 +112,7 @@ static bool TerminateOnNcclError() { //===----------------------------------------------------------------------===// NcclCliqueCommunicators::NcclCliqueCommunicators( - NcclCliqueKey clique_key, std::optional clique_id, + GpuCliqueKey clique_key, std::optional clique_id, absl::btree_map> communicators) : clique_key_(std::move(clique_key)), clique_id_(std::move(clique_id)), @@ -158,7 +158,7 @@ namespace { // Container for initialized and ready to use local (in-process) NCCL cliques. struct NcclCliques { absl::Mutex mu; - absl::node_hash_map map ABSL_GUARDED_BY(mu); + absl::node_hash_map map ABSL_GUARDED_BY(mu); }; } // namespace @@ -186,7 +186,7 @@ static absl::Status CheckComm(Communicator* comm) { } // Runs async check on all communicators in a clique. -static void CheckClique(const NcclCliqueKey& clique_key, +static void CheckClique(const GpuCliqueKey& clique_key, NcclClique& lockable_clique) { if (TerminateOnNcclError()) { absl::Status status = lockable_clique.CheckAsyncErrors(); @@ -253,7 +253,7 @@ static auto DeviceRanksToString(absl::Span ranks) { // a lock that gives an access to initialized clique (access is shared between // all participating ranks that own a shared pointer). static absl::StatusOr> InitializeNcclClique( - se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, + se::StreamExecutor* device, RunId run_id, GpuCliqueKey clique_key, const CliqueIdCallback& clique_id_callback, int32_t num_local_participants, RankId rank, NcclApi::Config& config) { int nranks = clique_key.devices().size(); @@ -359,7 +359,7 @@ static absl::StatusOr> InitializeNcclClique( // Computes a unique NCCL communicator split color from a clique key. We use a // deterministic hash function to guarantee that all participating processes get // the same color value for a clique. -static int32_t GetCommSplitColor(const NcclCliqueKey& clique_key) { +static int32_t GetCommSplitColor(const GpuCliqueKey& clique_key) { std::vector global_device_ids; global_device_ids.reserve(clique_key.devices().size()); @@ -377,11 +377,11 @@ static int32_t GetCommSplitColor(const NcclCliqueKey& clique_key) { // `parent_clique` clique (access is shared between all participating ranks that // own a shared pointer). static absl::StatusOr> InitializeNcclClique( - se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, + se::StreamExecutor* device, RunId run_id, GpuCliqueKey clique_key, std::shared_ptr parent_clique, int32_t num_local_participants, RankId rank, NcclApi::Config& config) { // Find our rank in the parent clique. - const NcclCliqueKey& parent_clique_key = (*parent_clique)->clique_key(); + const GpuCliqueKey& parent_clique_key = (*parent_clique)->clique_key(); RankId parent_rank = *parent_clique_key.rank(clique_key.devices()[rank.value()]); @@ -495,7 +495,7 @@ static absl::StatusOr> InitializeNcclClique( using AcquiredCliquesMap = NcclClique::AcquiredCliquesMap; absl::StatusOr> AcquireNcclClique( - se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, + se::StreamExecutor* device, RunId run_id, GpuCliqueKey clique_key, const CliqueIdCallback& clique_id_callback, RankId rank, size_t num_local_participants, const AcquiredCliquesMap& acquired_cliques, int64_t max_nchannels) { diff --git a/xla/service/gpu/runtime/nccl_clique.h b/xla/service/gpu/runtime/nccl_clique.h index 8fde5d0eff3fa..4ed4f2bdd4008 100644 --- a/xla/service/gpu/runtime/nccl_clique.h +++ b/xla/service/gpu/runtime/nccl_clique.h @@ -29,12 +29,12 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/core/collectives/clique_id.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" #include "xla/executable_run_options.h" #include "xla/service/gpu/gpu_executable_run_options.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/lockable.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" // IWYU pragma: keep @@ -55,7 +55,7 @@ namespace xla::gpu { // participating devices and properties of collective operations launched on // them, e.g. mixing NCCL operations launched from CUDA graphs with regularly // launched operations is prone to dead locks, and we keep them separate. See -// NcclCliqueKey for details. +// GpuCliqueKey for details. // // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using-multiple-nccl-communicators-concurrently @@ -96,7 +96,7 @@ class NcclCliqueCommunicators { }; NcclCliqueCommunicators( - NcclCliqueKey clique_key, std::optional clique_id, + GpuCliqueKey clique_key, std::optional clique_id, absl::btree_map> communicators); // Returns a NCCL communicator for a given rank if it's in a clique. @@ -109,7 +109,7 @@ class NcclCliqueCommunicators { // Calls `fn` for each communicator in the clique. void ForEachComm(absl::FunctionRef fn); - const NcclCliqueKey& clique_key() const { return clique_key_; } + const GpuCliqueKey& clique_key() const { return clique_key_; } const std::optional& clique_id() const { return clique_id_; } size_t num_communicators() const { return communicators_.size(); } @@ -118,7 +118,7 @@ class NcclCliqueCommunicators { AsyncErrorChecker GetChecker() { return AsyncErrorChecker(*this); } private: - NcclCliqueKey clique_key_; + GpuCliqueKey clique_key_; std::optional clique_id_; // TODO(ezhulenev): Switch this map to GlobalDeviceId key. @@ -136,8 +136,8 @@ class NcclClique : public Lockable { // We keep acquired cliques in a sorted container to guarantee that all // participants iterate over cliques in the same order. using AcquiredCliquesMap = - absl::btree_map, - std::greater>; + absl::btree_map, + std::greater>; // Construct the lockable clique. // Note that async errors can be checked without acquiring the lock. @@ -145,7 +145,7 @@ class NcclClique : public Lockable { // error checks, the constructor intentionally leaks the reference // to the communicators from an acquired lock. NcclClique( - NcclCliqueKey clique_key, std::optional clique_id, + GpuCliqueKey clique_key, std::optional clique_id, absl::btree_map> communicators) : Lockable(std::move(clique_key), clique_id, std::move(communicators)), async_error_checker_(Acquire()->GetChecker()) {} @@ -169,7 +169,7 @@ class NcclClique : public Lockable { // created communicators or maybe created by splitting of the already acquired // cliques. absl::StatusOr> AcquireNcclClique( - se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, + se::StreamExecutor* device, RunId run_id, GpuCliqueKey clique_key, const CliqueIdCallback& clique_id_callback, RankId rank, size_t num_local_participants, const NcclClique::AcquiredCliquesMap& acquired_cliques, diff --git a/xla/service/gpu/runtime/nccl_clique_key.h b/xla/service/gpu/runtime/nccl_clique_key.h deleted file mode 100644 index c46395647548e..0000000000000 --- a/xla/service/gpu/runtime/nccl_clique_key.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_CLIQUE_KEY_H_ -#define XLA_SERVICE_GPU_RUNTIME_NCCL_CLIQUE_KEY_H_ - -#include "xla/backends/gpu/collectives/gpu_clique_key.h" - -namespace xla::gpu { - -//===----------------------------------------------------------------------===// -// NcclCliqueKey -//===----------------------------------------------------------------------===// - -// TODO(b/380457503): Delete this alias. -using NcclCliqueKey = GpuCliqueKey; - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_CLIQUE_KEY_H_ diff --git a/xla/service/gpu/runtime/nccl_collective_thunk.cc b/xla/service/gpu/runtime/nccl_collective_thunk.cc index 6aa61d4fdb31c..5162d1b04f269 100644 --- a/xla/service/gpu/runtime/nccl_collective_thunk.cc +++ b/xla/service/gpu/runtime/nccl_collective_thunk.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" @@ -47,7 +48,6 @@ limitations under the License. #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/runtime/nccl_api.h" #include "xla/service/gpu/runtime/nccl_clique.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/rendezvous.h" #include "xla/shape.h" @@ -215,7 +215,7 @@ NcclCollectiveThunk::NcclCollectiveThunk(Kind kind, ThunkInfo thunk_info, nccl_api_(nccl_api), async_events_(is_sync ? nullptr : new AsyncEvents()) {} -absl::StatusOr GetNcclCliqueKey( +absl::StatusOr GetGpuCliqueKey( const Thunk::CollectiveExecuteParams& params, const std::vector& replica_groups, CollectiveOpGroupMode group_mode, CollectiveStreamId stream_id, @@ -248,9 +248,9 @@ absl::StatusOr GetNcclCliqueKey( static const bool enable_per_stream_comms = xla::GetDebugOptionsFromFlags().xla_gpu_enable_nccl_per_stream_comms(); - return NcclCliqueKey(std::move(participants), - enable_per_stream_comms ? stream_id : kNoStreamId, - stream_kind, std::move(participant_groups)); + return GpuCliqueKey(std::move(participants), + enable_per_stream_comms ? stream_id : kNoStreamId, + stream_kind, std::move(participant_groups)); } absl::StatusOr GetNcclComm( @@ -259,9 +259,9 @@ absl::StatusOr GetNcclComm( const std::vector& replica_groups, CollectiveOpGroupMode group_mode, CollectiveStreamId stream_id, AsyncStreamKind stream_kind) { - TF_ASSIGN_OR_RETURN(NcclCliqueKey clique_key, - GetNcclCliqueKey(params, replica_groups, group_mode, - stream_id, stream_kind)); + TF_ASSIGN_OR_RETURN(GpuCliqueKey clique_key, + GetGpuCliqueKey(params, replica_groups, group_mode, + stream_id, stream_kind)); std::optional rank = clique_key.rank(params.global_device_id); TF_ASSIGN_OR_RETURN(bool is_local, @@ -379,10 +379,10 @@ absl::StatusOr NcclCollectiveThunk::AsyncEvents::GetEvent( absl::Status NcclCollectiveThunk::Prepare(const PrepareParams& params, ResourceRequests& resource_requests) { TF_ASSIGN_OR_RETURN( - NcclCliqueKey clique_key, - GetNcclCliqueKey(*params.collective_params, config().replica_groups, - config().group_mode, nccl_stream_id(), - GetAsyncStreamKind())); + GpuCliqueKey clique_key, + GetGpuCliqueKey(*params.collective_params, config().replica_groups, + config().group_mode, nccl_stream_id(), + GetAsyncStreamKind())); TF_ASSIGN_OR_RETURN( size_t num_local_participants, GetNumLocalParticipants(*params.collective_params, @@ -398,10 +398,10 @@ absl::Status NcclCollectiveThunk::Initialize(const InitializeParams& params) { } namespace { -// Wrap NcclCliqueKey into a unique struct to guarantee we do not accidentally +// Wrap GpuCliqueKey into a unique struct to guarantee we do not accidentally // try to run multiple unrelated rendezvous for a same key. struct FirstCallRendezvousKey { - NcclCliqueKey clique_key; + GpuCliqueKey clique_key; template friend H AbslHashValue(H h, const FirstCallRendezvousKey& key) { @@ -453,9 +453,9 @@ absl::Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { // ahead on one rank leads to deadlocks in NCCL. if (NeedFirstCallRendzevous() && !first_call_rendezvous_flag_.IsCompleted()) { TF_ASSIGN_OR_RETURN( - NcclCliqueKey clique_key, - GetNcclCliqueKey(*params.collective_params, config().replica_groups, - config().group_mode, stream_id, stream_kind)); + GpuCliqueKey clique_key, + GetGpuCliqueKey(*params.collective_params, config().replica_groups, + config().group_mode, stream_id, stream_kind)); TF_ASSIGN_OR_RETURN( size_t num_local_participants, diff --git a/xla/service/gpu/runtime/nccl_collective_thunk.h b/xla/service/gpu/runtime/nccl_collective_thunk.h index 48ffbd4bd8604..c45d0514895d8 100644 --- a/xla/service/gpu/runtime/nccl_collective_thunk.h +++ b/xla/service/gpu/runtime/nccl_collective_thunk.h @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Value.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/core/collectives/communicator.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -43,7 +44,6 @@ limitations under the License. #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/runtime/nccl_api.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/rendezvous.h" @@ -282,7 +282,7 @@ absl::Status AddOpDescription(absl::Status status, OpT op, //===----------------------------------------------------------------------===// -absl::StatusOr GetNcclCliqueKey( +absl::StatusOr GetGpuCliqueKey( const Thunk::CollectiveExecuteParams& params, const std::vector& replica_groups, CollectiveOpGroupMode group_mode, CollectiveStreamId stream_id, diff --git a/xla/service/gpu/runtime/nccl_p2p_thunk_common.cc b/xla/service/gpu/runtime/nccl_p2p_thunk_common.cc index 494fde4a32850..12e7f8e411ee9 100644 --- a/xla/service/gpu/runtime/nccl_p2p_thunk_common.cc +++ b/xla/service/gpu/runtime/nccl_p2p_thunk_common.cc @@ -26,11 +26,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "mlir/IR/BuiltinAttributes.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/shape.h" #include "xla/status_macros.h" #include "xla/stream_executor/stream_executor.h" diff --git a/xla/service/gpu/runtime/nccl_p2p_thunk_common.h b/xla/service/gpu/runtime/nccl_p2p_thunk_common.h index 568abfba905ac..5b8f6f5cc3fab 100644 --- a/xla/service/gpu/runtime/nccl_p2p_thunk_common.h +++ b/xla/service/gpu/runtime/nccl_p2p_thunk_common.h @@ -27,8 +27,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "mlir/IR/BuiltinAttributes.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/shape.h" diff --git a/xla/service/gpu/runtime/nccl_recv_thunk.h b/xla/service/gpu/runtime/nccl_recv_thunk.h index de2dd0bbf5d80..b0c6c16b8d4e6 100644 --- a/xla/service/gpu/runtime/nccl_recv_thunk.h +++ b/xla/service/gpu/runtime/nccl_recv_thunk.h @@ -21,10 +21,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/core/collectives/communicator.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/runtime/nccl_api.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" #include "xla/stream_executor/stream.h" diff --git a/xla/service/gpu/runtime/nccl_send_thunk.h b/xla/service/gpu/runtime/nccl_send_thunk.h index 6f6f59618dc96..795a505f5aadd 100644 --- a/xla/service/gpu/runtime/nccl_send_thunk.h +++ b/xla/service/gpu/runtime/nccl_send_thunk.h @@ -21,11 +21,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/core/collectives/communicator.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/runtime/nccl_api.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" #include "xla/stream_executor/stream.h" diff --git a/xla/service/gpu/runtime/thunk.cc b/xla/service/gpu/runtime/thunk.cc index 380276baef37a..e7ccc1e360e3b 100644 --- a/xla/service/gpu/runtime/thunk.cc +++ b/xla/service/gpu/runtime/thunk.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" #include "xla/executable_run_options.h" @@ -39,7 +40,6 @@ limitations under the License. #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/gpu/runtime/nccl_clique.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/service_executable_run_options.h" #include "xla/stream_executor/stream.h" #include "tsl/platform/statusor.h" @@ -56,7 +56,7 @@ Thunk::CollectiveCliques::CollectiveCliques( : cliques_map_(std::move(cliques_map)) {} absl::StatusOr Thunk::CollectiveCliques::GetComm( - const NcclCliqueKey& clique_key, RankId rank) const { + const GpuCliqueKey& clique_key, RankId rank) const { // Check that we locked access to a clique for `clique_key`. auto clique = cliques_map_.find(clique_key); if (clique == cliques_map_.end()) { @@ -76,7 +76,7 @@ absl::StatusOr Thunk::CollectiveCliques::GetComm( } absl::StatusOr Thunk::CollectiveCliques::is_local_clique( - const NcclCliqueKey& clique_key) const { + const GpuCliqueKey& clique_key) const { // Check that we locked access to a clique for `clique_key`. auto clique = cliques_map_.find(clique_key); if (clique == cliques_map_.end()) { @@ -88,7 +88,7 @@ absl::StatusOr Thunk::CollectiveCliques::is_local_clique( } absl::StatusOr Thunk::CollectiveCliques::num_communicators( - const NcclCliqueKey& clique_key) const { + const GpuCliqueKey& clique_key) const { // Check that we locked access to a clique for `clique_key`. auto clique = cliques_map_.find(clique_key); if (clique == cliques_map_.end()) { diff --git a/xla/service/gpu/runtime/thunk.h b/xla/service/gpu/runtime/thunk.h index 2294e6fa0c20f..7ccc463cd8b9f 100644 --- a/xla/service/gpu/runtime/thunk.h +++ b/xla/service/gpu/runtime/thunk.h @@ -31,6 +31,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" #include "xla/executable_run_options.h" @@ -42,7 +43,6 @@ limitations under the License. #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/runtime/nccl_clique.h" -#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/service_executable_run_options.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -200,7 +200,7 @@ class Thunk { class ResourceRequests { public: virtual ~ResourceRequests() = default; - virtual absl::Status AddClique(const NcclCliqueKey& clique_key, + virtual absl::Status AddClique(const GpuCliqueKey& clique_key, int32_t num_local_participants) = 0; }; @@ -215,16 +215,16 @@ class Thunk { CollectiveCliques() = default; explicit CollectiveCliques(NcclClique::AcquiredCliquesMap cliques_map); - absl::StatusOr GetComm(const NcclCliqueKey& clique_key, + absl::StatusOr GetComm(const GpuCliqueKey& clique_key, RankId rank) const; // Returns the number of communicators in a collective clique. Returns error // if we do not have an acquired clique for a given key. absl::StatusOr num_communicators( - const NcclCliqueKey& clique_key) const; + const GpuCliqueKey& clique_key) const; // Returns whether the clique is a local clique. - absl::StatusOr is_local_clique(const NcclCliqueKey& clique_key) const; + absl::StatusOr is_local_clique(const GpuCliqueKey& clique_key) const; bool empty() const { return cliques_map_.empty(); }