From c4bc2b28d72d5d94be8d982c8b141dc7b099c4e8 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Thu, 12 Dec 2024 07:36:36 -0800 Subject: [PATCH] [XLA:GPU] Readability and performance nits. PiperOrigin-RevId: 705498812 --- xla/service/gpu/BUILD | 1 - xla/service/gpu/gpu_hlo_schedule.cc | 186 +++++++++++++--------------- 2 files changed, 87 insertions(+), 100 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 8ca13cee5e735..4d7f53665c17b 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -2113,7 +2113,6 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/service:buffer_value", "//xla/service:collective_ops_utils", - "//xla/service:collective_utils", "//xla/service:latency_hiding_scheduler", "//xla/service:p2p_schedule_preparation", "//xla/service:profile_guided_latency_estimator", diff --git a/xla/service/gpu/gpu_hlo_schedule.cc b/xla/service/gpu/gpu_hlo_schedule.cc index b29e25980d57f..a85bbace1fe69 100644 --- a/xla/service/gpu/gpu_hlo_schedule.cc +++ b/xla/service/gpu/gpu_hlo_schedule.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" @@ -278,16 +279,15 @@ SchedulerConfig GetSchedulerConfig(int64_t memory_limit, tensorflow::profiler::ProfiledInstructionsProto GetProfileForFingerprint( tensorflow::profiler::ProfiledInstructionsProto& profile, - const std::string& fingerprint) { + absl::string_view fingerprint) { tensorflow::profiler::ProfiledInstructionsProto result; bool merge_remat_clones = false; for (const auto& cost : profile.costs()) { - absl::string_view cost_name = cost.name(); std::string new_cost_name = cost.name(); absl::string_view cost_sep = "::"; - if (absl::StrContains(cost_name, cost_sep)) { - std::vector split_names = - absl::StrSplit(cost_name, cost_sep); + if (absl::StrContains(cost.name(), cost_sep)) { + std::vector split_names = + absl::StrSplit(cost.name(), cost_sep); if (split_names.size() != 2 || split_names[0] != fingerprint) { continue; } @@ -325,30 +325,33 @@ tensorflow::profiler::ProfiledInstructionsProto GetProfileForFingerprint( return name; }; - // Map from stripped name -> pair - absl::flat_hash_map> costs; + struct Data { + double accumulated_cost = 0.0; + int64_t count = 0; + }; + absl::flat_hash_map costs; for (const auto& cost : result.costs()) { - std::pair& data = costs[strip_remat_suffix(cost.name())]; - data.first += cost.cost_us(); - data.second++; + Data& data = costs[strip_remat_suffix(cost.name())]; + data.accumulated_cost += cost.cost_us(); + data.count++; } tensorflow::profiler::ProfiledInstructionsProto merged_result; - for (const auto& cost : costs) { + for (const auto& [name, data] : costs) { auto* new_cost = merged_result.add_costs(); - double average = cost.second.first / cost.second.second; + double average = data.accumulated_cost / data.count; new_cost->set_cost_us(average); - new_cost->set_name(std::string(cost.first)); + new_cost->set_name(std::string(name)); } return merged_result; } std::optional ReadPGLEProfile( - const HloModule* module, const std::string& fingerprint) { + const HloModule& module, absl::string_view fingerprint) { tensorflow::profiler::ProfiledInstructionsProto profile; - absl::string_view fdo_profile = module->config().fdo_profile(); + absl::string_view fdo_profile = module.config().fdo_profile(); // First attempt to read the profile from `fdo_profile` in ModuleConfig if (!fdo_profile.empty()) { // Attempt to parse it as a binary proto. @@ -369,14 +372,14 @@ std::optional ReadPGLEProfile( } const std::string& pgle_profile_file_or_dir_path = - module->config() + module.config() .debug_options() .xla_gpu_pgle_profile_file_or_directory_path(); if (pgle_profile_file_or_dir_path.empty()) { return std::nullopt; } tsl::Env* env = tsl::Env::Default(); - auto read_text_or_binary_profile = [&profile, env, &fingerprint]( + auto read_text_or_binary_profile = [&profile, env, fingerprint]( const std::string& text_path, const std::string& binary_path) -> std::optional { @@ -409,7 +412,7 @@ std::optional ReadPGLEProfile( // specific module. if (env->IsDirectory(pgle_profile_file_or_dir_path).ok()) { std::string pgle_profile_path_prefix = - pgle_profile_file_or_dir_path + "/" + fingerprint; + absl::StrCat(pgle_profile_file_or_dir_path, "/", fingerprint); return read_text_or_binary_profile(pgle_profile_path_prefix + ".pbtxt", pgle_profile_path_prefix + ".pb"); } @@ -446,7 +449,7 @@ std::string TagWithFingerprint(HloModule* module) { HloPrintOptions::Canonical().set_print_backend_config(true)); FrontendAttributes attributes; (*attributes.mutable_map())[std::string(kFingerprintBeforeLHS)] = fingerprint; - module->add_frontend_attributes(attributes); + module->add_frontend_attributes(std::move(attributes)); VLOG(1) << "Fingerprint before LHS for module " << module->name() << "(" << module->unique_id() << ") = " << fingerprint; return fingerprint; @@ -457,16 +460,16 @@ std::string TagWithFingerprint(HloModule* module) { // additionally add fail-fast/warn checks to the pipeline which act in the // absence of instruction in the profile. See `PGLEAccuracyChecker` for details. std::unique_ptr GetLatencyEstimator( - HloModule* module, int pointer_size, + const HloModule& module, int pointer_size, const se::DeviceDescription& gpu_device_info, absl::string_view fingerprint, const SchedulerConfig& config, HloPassPipeline& pipeline) { - const DebugOptions& options = module->config().debug_options(); + const DebugOptions& options = module.config().debug_options(); auto gpu_latency_estimator = std::make_unique(pointer_size); std::optional profile = - ReadPGLEProfile(module, std::string(fingerprint)); + ReadPGLEProfile(module, fingerprint); if (profile.has_value()) { auto aggregator = std::make_unique(); @@ -491,7 +494,7 @@ std::unique_ptr GetLatencyEstimator( [input_pointer_size = pointer_size](const Shape& shape) { return GetSizeOfShape(shape, input_pointer_size); }, - module->entry_computation()); + module.entry_computation()); } return gpu_latency_estimator; } @@ -520,7 +523,7 @@ absl::Status RunLatencyHidingSchedulerPasses( HloPassPipeline pipeline("latency-hiding-scheduler"); std::unique_ptr latency_estimator = GetLatencyEstimator( - module, pointer_size, gpu_device_info, fingerprint, config, pipeline); + *module, pointer_size, gpu_device_info, fingerprint, config, pipeline); auto scheduler_core = std::make_unique( shape_size_in_bytes, async_tracker.get(), latency_estimator.get(), config, @@ -538,11 +541,56 @@ absl::Status RunLatencyHidingSchedulerPasses( return pipeline.Run(module).status(); } -} // end namespace +// Compute the device memory limit to be used by passes like scheduler and +// HLO rematerialization. +int64_t GetSchedulerMemoryLimit(const HloModule& module, + const se::DeviceDescription& gpu_device_info, + int pointer_size) { + // There is a "base" value which is either specified in HloModuleConfig (this + // value should take into account the fact that we need to leave some memory + // free for allocations that happen outside of XLA's allocator) or + // obtained from GPU device info (we scale down this value to leave some space + // for these outside XLA's allocator allocation). + // + // From that base value, subtract any input and output sizes (assuming they + // are live throughout the execution) and then apply a slop factor. + const int64_t base_limit = + module.config().device_memory_size() != 0 + ? module.config().device_memory_size() + : gpu_device_info.device_memory_size() * 80 / 100; -static int64_t GetSchedulerMemoryLimit( - const HloModule* module, const se::DeviceDescription& gpu_device_info, - int pointer_size); + // Find the total size of inputs and outputs. + int64_t total_io_size = 0; + for (HloInstruction* param : + module.entry_computation()->parameter_instructions()) { + ShapeUtil::ForEachSubshape( + param->shape(), + [&](const Shape& subshape, const ShapeIndex& /*index*/) { + total_io_size += GetSizeOfShape(subshape, pointer_size); + }); + } + ShapeUtil::ForEachSubshape( + module.result_shape(), + [&](const Shape& subshape, const ShapeIndex& /*index*/) { + total_io_size += GetSizeOfShape(subshape, pointer_size); + }); + + // If any inputs and outputs are aliased, do not double count them. + module.input_output_alias_config().ForEachAlias( + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias&) { + const Shape& subshape = + ShapeUtil::GetSubshape(module.result_shape(), output_index); + total_io_size -= GetSizeOfShape(subshape, pointer_size); + }); + + int64_t limit = + (base_limit - total_io_size) * + module.config().debug_options().xla_gpu_memory_limit_slop_factor() / 100; + return limit; +} + +} // end namespace absl::StatusOr ScheduleGpuModule( HloModule* module, int64_t pointer_size, @@ -553,21 +601,21 @@ absl::StatusOr ScheduleGpuModule( // instruction name with ids. std::string fingerprint = TagWithFingerprint(module); int64_t memory_limit = - GetSchedulerMemoryLimit(module, gpu_device_info, pointer_size); + GetSchedulerMemoryLimit(*module, gpu_device_info, pointer_size); - // Case 1: Module has a schedule. - // - // Return already existing schedule. + // Module already has a schedule, do nothing. if (module->has_schedule()) { return ScheduleMetadata{memory_limit}; } - // Case 2: Module does not have a schedule. - // - // Running default scheduler. + // Run the scheduler which minimizes peak memory usage. // We need to run it anyway because LHS relies on it track buffers. See // `xla::BufferInfoTracker::BufferInfoTracker()`. TF_RETURN_IF_ERROR(RunP2PSchedulePreparation(module)); + TF_ASSIGN_OR_RETURN( + HloSchedule schedule, + ScheduleGpuModuleWithMemoryScheduler(module, pointer_size)); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); bool enable_latency_hiding_scheduler = module->config() @@ -575,23 +623,12 @@ absl::StatusOr ScheduleGpuModule( .xla_gpu_enable_latency_hiding_scheduler() || IsPassEnabledAtOptimizationEffort(*module); - // Default behaviour. Run the scheduler which minimizes peak memory usage. - TF_ASSIGN_OR_RETURN( - HloSchedule schedule, - ScheduleGpuModuleWithMemoryScheduler(module, pointer_size)); - TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); - - // LHS disabled, we return a default schedule. - if (!enable_latency_hiding_scheduler) { - return ScheduleMetadata{memory_limit}; - } - - // Case 3: LHS enabled. - // // Run Latency Hiding Scheduler (LHS). It maximizes the compute-communication // overlap, potentially at the cost of memory usage. - TF_RETURN_IF_ERROR(RunLatencyHidingSchedulerPasses( - module, pointer_size, fingerprint, memory_limit, gpu_device_info)); + if (enable_latency_hiding_scheduler) { + TF_RETURN_IF_ERROR(RunLatencyHidingSchedulerPasses( + module, pointer_size, fingerprint, memory_limit, gpu_device_info)); + } return ScheduleMetadata{memory_limit}; } @@ -614,54 +651,5 @@ HloInstructionSequence PostProcessSchedule( return PostprocessorToScheduleAsEarlyOrLateAsPossible(result); } -// Compute the device memory limit to be used by passes like scheduler and -// HLO rematerialization. -static int64_t GetSchedulerMemoryLimit( - const HloModule* module, const se::DeviceDescription& gpu_device_info, - int pointer_size) { - // There is a "base" value which is either specified in HloModuleConfig (this - // value should take into account the fact that we need to leave some memory - // free for allocations that happen outside of XLA's allocator) or - // obtained from GPU device info (we scale down this value to leave some space - // for these outside XLA's allocator allocation). - // - // From that base value, subtract any input and output sizes (assuming they - // are live throughout the execution) and then apply a slop factor. - const int64_t base_limit = - module->config().device_memory_size() != 0 - ? module->config().device_memory_size() - : gpu_device_info.device_memory_size() * 80 / 100; - - // Find the total size of inputs and outputs. - int64_t total_io_size = 0; - for (HloInstruction* param : - module->entry_computation()->parameter_instructions()) { - ShapeUtil::ForEachSubshape( - param->shape(), - [&](const Shape& subshape, const ShapeIndex& /*index*/) { - total_io_size += GetSizeOfShape(subshape, pointer_size); - }); - } - ShapeUtil::ForEachSubshape( - module->result_shape(), - [&](const Shape& subshape, const ShapeIndex& /*index*/) { - total_io_size += GetSizeOfShape(subshape, pointer_size); - }); - - // If any inputs and outputs are aliased, do not double count them. - module->input_output_alias_config().ForEachAlias( - [&](const ShapeIndex& output_index, - const HloInputOutputAliasConfig::Alias&) { - const Shape& subshape = - ShapeUtil::GetSubshape(module->result_shape(), output_index); - total_io_size -= GetSizeOfShape(subshape, pointer_size); - }); - - int64_t limit = - (base_limit - total_io_size) * - module->config().debug_options().xla_gpu_memory_limit_slop_factor() / 100; - return limit; -} - } // namespace gpu } // namespace xla