Skip to content

Commit

Permalink
added ConvBfloat16Support HLO pass
Browse files Browse the repository at this point in the history
  • Loading branch information
pemeliya authored and hsharsha committed Mar 5, 2024
1 parent b30236d commit 426e59e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
37 changes: 37 additions & 0 deletions xla/service/gpu/amdgpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "xla/service/algebraic_simplifier.h"
#include "xla/service/call_inliner.h"
#include "xla/service/dot_dimension_merger.h"
#include "xla/service/float_normalization.h"
#include "xla/service/gpu/conv_algorithm_picker.h"
#include "xla/service/gpu/cublas_pad_for_gemms.h"
#include "xla/service/gpu/cublas_padding_requirements.h"
Expand Down Expand Up @@ -53,6 +54,36 @@ limitations under the License.
namespace xla {
namespace gpu {

namespace {

// FloatSupport policy for BF16 that treats convolutions specially.
//
// Every non-convolution instruction is reported as accepting BF16 operands,
// BF16 outputs and mixed precisions. For convolutions, BF16 operands/outputs
// are accepted only when the flag captured at construction time is set, and
// mixed precisions are never accepted.
struct ConvBfloat16Support : public FloatSupport {
  // Captures `rocm.has_bf16_dtype_support()` as the "convolutions may use
  // BF16" flag used by the predicates below.
  explicit ConvBfloat16Support(const se::RocmComputeCapability& rocm)
      : FloatSupport(BF16),
        // TODO: MIOpen does not support bf16 convolutions yet
        is_conv_bf16_supported_(rocm.has_bf16_dtype_support()) {}

  // True for any non-convolution; for a convolution, true only when BF16
  // convolutions are supported. `operand_index` does not affect the answer.
  bool SupportsLowPrecisionOperand(const HloInstruction& hlo,
                                   int64_t operand_index) const override {
    return AllowsLowPrecision(hlo);
  }

  // Same rule as SupportsLowPrecisionOperand, applied to the output.
  bool SupportsLowPrecisionOutput(const HloInstruction& hlo) const override {
    return AllowsLowPrecision(hlo);
  }

  // Convolutions never report mixed-precision support; every other HLO does,
  // so only convolutions are singled out by this policy.
  bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
    return !IsConvolution(hlo);
  }

 private:
  // Whether `hlo` is a kConvolution instruction.
  static bool IsConvolution(const HloInstruction& hlo) {
    return hlo.opcode() == HloOpcode::kConvolution;
  }

  // Shared rule for the operand/output predicates: non-convolutions always
  // pass, convolutions pass only if BF16 convolutions are supported.
  bool AllowsLowPrecision(const HloInstruction& hlo) const {
    if (!IsConvolution(hlo)) {
      return true;
    }
    return is_conv_bf16_supported_;
  }

  // Set once in the constructor from the ROCm compute capability.
  bool is_conv_bf16_supported_;
};

} // namespace

absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
HloModule* hlo_module, se::GpuComputeCapability gpu_version,
se::dnn::VersionInfo dnn_version,
Expand All @@ -63,6 +94,12 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
pipeline.AddInvariantCheckerDebug<HloVerifier>(
/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);

// Convert unsupported bf16 convolutions to f32.
ConvBfloat16Support conv_bf16_support(
std::get<se::RocmComputeCapability>(gpu_version));
pipeline.AddPass<FloatNormalization>(&conv_bf16_support);

pipeline.AddPass<GpusolverRewriter>();
pipeline.AddPass<GpuConvRewriter>();
pipeline.AddPass<GpuConvPaddingLegalization>();
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/rocm/rocm_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace stream_executor {
namespace gpu {
// Formats hipError_t to output prettified values into a log stream.
// Error summaries taken from:
string ToString(hipError_t result);
std::string ToString(hipError_t result);

// GpuContext wraps the device_ordinal and hipCtx_t handle.
class GpuContext {
Expand Down

0 comments on commit 426e59e

Please sign in to comment.