From 426e59e131feef1b3d6b2736f455803e1e58a8c3 Mon Sep 17 00:00:00 2001 From: Pavel Emeliyanenko Date: Fri, 19 Jan 2024 14:47:57 +0000 Subject: [PATCH] added ConvBfloat16Support HLO pass --- xla/service/gpu/amdgpu_compiler.cc | 37 ++++++++++++++++++++++++++ xla/stream_executor/rocm/rocm_driver.h | 2 +- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/xla/service/gpu/amdgpu_compiler.cc b/xla/service/gpu/amdgpu_compiler.cc index 0a985871914cd..4eff13885d8b1 100644 --- a/xla/service/gpu/amdgpu_compiler.cc +++ b/xla/service/gpu/amdgpu_compiler.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/service/algebraic_simplifier.h" #include "xla/service/call_inliner.h" #include "xla/service/dot_dimension_merger.h" +#include "xla/service/float_normalization.h" #include "xla/service/gpu/conv_algorithm_picker.h" #include "xla/service/gpu/cublas_pad_for_gemms.h" #include "xla/service/gpu/cublas_padding_requirements.h" @@ -53,6 +54,36 @@ limitations under the License. namespace xla { namespace gpu { +namespace { + +struct ConvBfloat16Support : public FloatSupport { + + explicit ConvBfloat16Support( + const se::RocmComputeCapability& rocm) + : FloatSupport(BF16), + // TODO: MIOpen does not support bf16 convolutions yet + is_conv_bf16_supported_(rocm.has_bf16_dtype_support()) {} + + bool SupportsLowPrecisionOperand(const HloInstruction& hlo, + int64_t operand_index) const override { + return (hlo.opcode() != HloOpcode::kConvolution) || is_conv_bf16_supported_; + } + + bool SupportsLowPrecisionOutput(const HloInstruction& hlo) const override { + return (hlo.opcode() != HloOpcode::kConvolution) || is_conv_bf16_supported_; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + // Skip all HLOs other than convolutions. 
+ return (hlo.opcode() != HloOpcode::kConvolution); + } + + private: + bool is_conv_bf16_supported_; +}; + +} // namespace + absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, @@ -63,6 +94,12 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddInvariantCheckerDebug( /*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + + // Convert unsupported bf16 convolutions to f32. + ConvBfloat16Support conv_bf16_support( + std::get(gpu_version)); + pipeline.AddPass(&conv_bf16_support); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); diff --git a/xla/stream_executor/rocm/rocm_driver.h b/xla/stream_executor/rocm/rocm_driver.h index f94d9bf8f0c52..4cb76dd8151bf 100644 --- a/xla/stream_executor/rocm/rocm_driver.h +++ b/xla/stream_executor/rocm/rocm_driver.h @@ -29,7 +29,7 @@ namespace stream_executor { namespace gpu { // Formats hipError_t to output prettified values into a log stream. // Error summaries taken from: -string ToString(hipError_t result); +std::string ToString(hipError_t result); // GpuContext wraps the device_ordinal and hipCtx_t handle. class GpuContext {