Skip to content

Commit

Permalink
[xla:collectives] NFC: Remove NcclCliqueKey alias
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702392147
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Dec 3, 2024
1 parent fab1800 commit a7c38b7
Show file tree
Hide file tree
Showing 27 changed files with 88 additions and 146 deletions.
3 changes: 2 additions & 1 deletion xla/backends/gpu/collectives/BUILD
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
load("//xla:xla.bzl", "xla_cc_test")
load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured")
load("//xla/tsl:tsl.bzl", "internal_visibility")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
load("//xla/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-default_visibility = [":friends"],
+default_visibility = internal_visibility([":friends"]),
licenses = ["notice"],
)

Expand Down
2 changes: 1 addition & 1 deletion xla/pjrt/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,12 @@ cc_library(
deps = [
"//xla:status_macros",
"//xla:util",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:global_device_id",
"//xla/service/gpu/runtime:nccl_api",
"//xla/service/gpu/runtime:nccl_clique_key",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
Expand Down
4 changes: 2 additions & 2 deletions xla/pjrt/gpu/nccl_id_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/service/gpu/runtime/nccl_api.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "tsl/platform/casts.h"
Expand All @@ -34,7 +34,7 @@ limitations under the License.
namespace xla {

absl::StatusOr<CliqueId> NcclIdStore::GetNcclUniqueId(const CliqueKey& key) {
-auto* gpu_key = tsl::down_cast<const gpu::NcclCliqueKey*>(&key);
+auto* gpu_key = tsl::down_cast<const gpu::GpuCliqueKey*>(&key);
if (gpu_key == nullptr) {
return InvalidArgument("Expected GPU clique key");
}
Expand Down
6 changes: 3 additions & 3 deletions xla/pjrt/gpu/nccl_id_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/global_device_id.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"

namespace xla {

-// A table mapping NcclCliqueKeys to NcclCliqueIds. In a distributed setup the
+// A table mapping GpuCliqueKeys to CliqueIds. In a distributed setup the
// table of NCCL IDs is kept on the master node (node 0). The node of the first
// participating device will create the unique id.
class NcclIdStore {
Expand All @@ -51,7 +51,7 @@ class NcclIdStore {
const std::shared_ptr<KeyValueStoreInterface> kv_store_;

absl::Mutex mu_;
-absl::flat_hash_map<gpu::NcclCliqueKey, CliqueId> cache_ ABSL_GUARDED_BY(mu_);
+absl::flat_hash_map<gpu::GpuCliqueKey, CliqueId> cache_ ABSL_GUARDED_BY(mu_);
};

} // namespace xla
Expand Down
6 changes: 3 additions & 3 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//xla:executable_run_options",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/service:global_device_id",
"//xla/service/gpu/runtime:nccl_clique_key",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
Expand Down Expand Up @@ -339,6 +339,7 @@ cc_library(
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/ffi:attribute_map",
"//xla/ffi:ffi_api",
"//xla/ffi/api:c_api",
Expand Down Expand Up @@ -375,7 +376,6 @@ cc_library(
"//xla/service/gpu/runtime:nccl_all_to_all_thunk",
"//xla/service/gpu/runtime:nccl_api",
"//xla/service/gpu/runtime:nccl_clique",
"//xla/service/gpu/runtime:nccl_clique_key",
"//xla/service/gpu/runtime:nccl_collective_broadcast_thunk",
"//xla/service/gpu/runtime:nccl_collective_permute_thunk",
"//xla/service/gpu/runtime:nccl_collective_thunk",
Expand Down Expand Up @@ -575,6 +575,7 @@ cc_library(
"//xla:shape_util",
"//xla:status_macros",
"//xla:util",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/hlo/ir:hlo",
"//xla/service:buffer_assignment",
"//xla/service:executable",
Expand All @@ -587,7 +588,6 @@ cc_library(
"//xla/service/gpu/runtime:annotation",
"//xla/service/gpu/runtime:for_all_thunks",
"//xla/service/gpu/runtime:nccl_clique",
"//xla/service/gpu/runtime:nccl_clique_key",
"//xla/service/gpu/runtime:sequential_thunk",
"//xla/service/gpu/runtime:thunk",
"//xla/stream_executor:device_description",
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/fusions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ cc_library(
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/ffi:attribute_map",
"//xla/ffi:ffi_api",
"//xla/hlo/ir:hlo",
Expand All @@ -95,7 +96,6 @@ cc_library(
"//xla/service/gpu/runtime:kernel_thunk",
"//xla/service/gpu/runtime:nccl_all_reduce_thunk",
"//xla/service/gpu/runtime:nccl_api",
"//xla/service/gpu/runtime:nccl_clique_key",
"//xla/service/gpu/runtime:nccl_collective_thunk",
"//xla/service/gpu/runtime:thunk",
"@com_google_absl//absl/algorithm:container",
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/fusions/custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/ffi/attribute_map.h"
#include "xla/ffi/ffi_api.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
Expand Down Expand Up @@ -64,7 +65,6 @@ limitations under the License.
#include "xla/service/gpu/runtime/kernel_thunk.h"
#include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h"
#include "xla/service/gpu/runtime/nccl_api.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"
#include "xla/service/gpu/runtime/nccl_collective_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/service/gpu/stream_executor_util.h"
Expand Down
8 changes: 4 additions & 4 deletions xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/executable_run_options.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand All @@ -51,7 +52,6 @@ limitations under the License.
#include "xla/service/gpu/runtime/annotation.h"
#include "xla/service/gpu/runtime/for_all_thunks.h"
#include "xla/service/gpu/runtime/nccl_clique.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/service/gpu/stream_executor_util.h"
Expand Down Expand Up @@ -187,7 +187,7 @@ namespace {
// Shared resources required for thunk initialization and execution.
class ResourceRequests : public Thunk::ResourceRequests {
public:
-absl::Status AddClique(const NcclCliqueKey& clique_key,
+absl::Status AddClique(const GpuCliqueKey& clique_key,
int32_t num_local_participants) final {
VLOG(5) << "Add collective clique request: " << clique_key.ToString()
<< "; num_local_participants: " << num_local_participants;
Expand Down Expand Up @@ -286,7 +286,7 @@ class ResourceRequests : public Thunk::ResourceRequests {

private:
struct CliqueRequest {
-NcclCliqueKey key;
+GpuCliqueKey key;
int64_t num_local_participants;
int64_t id;
};
Expand Down Expand Up @@ -326,7 +326,7 @@ class ResourceRequests : public Thunk::ResourceRequests {
return cliques;
}

-absl::flat_hash_map<NcclCliqueKey, CliqueRequest> cliques_;
+absl::flat_hash_map<GpuCliqueKey, CliqueRequest> cliques_;
};

absl::Status MaybeSyncAndProfile(const ServiceExecutableRunOptions* run_options,
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ limitations under the License.
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/attribute_map.h"
#include "xla/ffi/ffi_api.h"
Expand Down Expand Up @@ -126,7 +127,6 @@ limitations under the License.
#include "xla/service/gpu/runtime/nccl_all_to_all_thunk.h"
#include "xla/service/gpu/runtime/nccl_api.h"
#include "xla/service/gpu/runtime/nccl_clique.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"
#include "xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h"
#include "xla/service/gpu/runtime/nccl_collective_permute_thunk.h"
#include "xla/service/gpu/runtime/nccl_collective_thunk.h"
Expand Down
48 changes: 11 additions & 37 deletions xla/service/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ cc_library(
":nccl_all_reduce_thunk",
":nccl_all_to_all_thunk",
":nccl_api",
":nccl_clique_key",
":nccl_collective_broadcast_thunk",
":nccl_collective_thunk",
":thunk",
Expand All @@ -72,6 +71,7 @@ cc_library(
"//xla:status_macros",
"//xla:types",
"//xla:util",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/backends/gpu/collectives:gpu_collectives",
"//xla/core/collectives:communicator",
"//xla/ffi:call_frame",
Expand Down Expand Up @@ -212,10 +212,10 @@ cc_library(
hdrs = ["nccl_api.h"],
compatible_with = get_compatible_with_portable(),
deps = [
":nccl_clique_key",
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/backends/gpu/collectives:gpu_collectives",
"//xla/backends/gpu/collectives:nccl_communicator",
"//xla/core/collectives:clique_id",
Expand Down Expand Up @@ -257,9 +257,9 @@ cc_library(
hdrs = ["nccl_api.h"],
compatible_with = get_compatible_with_portable(),
deps = [
":nccl_clique_key",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/backends/gpu/collectives:gpu_collectives",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:communicator",
Expand All @@ -282,12 +282,12 @@ cc_library(
hdrs = ["nccl_clique.h"],
deps = [
":nccl_api",
":nccl_clique_key",
"//xla:debug_options_flags",
"//xla:executable_run_options",
"//xla:status_macros",
"//xla:types",
"//xla:util",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/core/collectives:communicator",
Expand Down Expand Up @@ -317,32 +317,6 @@ cc_library(
],
)

cc_library(
name = "nccl_clique_key",
hdrs = ["nccl_clique_key.h"],
compatible_with = get_compatible_with_portable(),
deps = [
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/backends/gpu/collectives:gpu_collectives",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/core/collectives:rank_id",
"//xla/service:global_device_id",
"//xla/tsl/lib/gtl:int_type",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/crc:crc32c",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:logging",
],
)

#===-------------------------------------------------------------------------------------------===//
# XLA Thunks Runtime
#===-------------------------------------------------------------------------------------------===//
Expand Down Expand Up @@ -868,11 +842,11 @@ cc_library(
hdrs = ["nccl_all_to_all_thunk.h"],
deps = [
":nccl_api",
":nccl_clique_key",
":nccl_collective_thunk",
":thunk",
"//xla:shape_util",
"//xla:status_macros",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/backends/gpu/collectives:gpu_collectives",
"//xla/core/collectives:communicator",
"//xla/hlo/ir:hlo",
Expand Down Expand Up @@ -946,12 +920,12 @@ cc_library(
deps = [
":nccl_api",
":nccl_clique",
":nccl_clique_key",
":thunk",
"//xla:debug_options_flags",
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/backends/gpu/collectives:gpu_collectives",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
Expand Down Expand Up @@ -995,12 +969,12 @@ cc_library(
srcs = ["nccl_p2p_thunk_common.cc"],
hdrs = ["nccl_p2p_thunk_common.h"],
deps = [
":nccl_clique_key",
":nccl_collective_thunk",
"//xla:executable_run_options",
"//xla:shape_util",
"//xla:status_macros",
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/core/collectives:communicator",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
Expand All @@ -1023,11 +997,11 @@ cc_library(
hdrs = ["nccl_recv_thunk.h"],
deps = [
":nccl_api",
":nccl_clique_key",
":nccl_collective_thunk",
":nccl_p2p_thunk_common",
":thunk",
"//xla:status_macros",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/core/collectives:communicator",
"//xla/hlo/ir:hlo",
"//xla/service:collective_ops_utils",
Expand All @@ -1049,11 +1023,11 @@ cc_library(
hdrs = ["nccl_send_thunk.h"],
deps = [
":nccl_api",
":nccl_clique_key",
":nccl_collective_thunk",
":nccl_p2p_thunk_common",
":thunk",
"//xla:status_macros",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/core/collectives:communicator",
"//xla/hlo/ir:hlo",
"//xla/service:collective_ops_utils",
Expand All @@ -1075,9 +1049,9 @@ cc_library(
hdrs = ["nccl_group_thunk.h"],
deps = [
":nccl_api",
":nccl_clique_key",
":nccl_collective_thunk",
":thunk",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/hlo/ir:hlo",
"//xla/stream_executor:stream",
"@com_google_absl//absl/status",
Expand Down Expand Up @@ -1189,8 +1163,8 @@ cc_library(
hdrs = ["thunk.h"],
deps = [
":nccl_clique",
":nccl_clique_key",
"//xla:executable_run_options",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/ffi:execution_context",
Expand Down
Loading

0 comments on commit a7c38b7

Please sign in to comment.