Add CudnnPadForConvolutions and CudnnVectorizeConvolutions HLO passes
hsharsha committed Mar 6, 2024
1 parent 426e59e commit 8f57a81
Showing 8 changed files with 47 additions and 16 deletions.
3 changes: 3 additions & 0 deletions xla/service/gpu/BUILD
@@ -3907,6 +3907,9 @@ cc_library(
":conv_algorithm_picker",
":cublas_pad_for_gemms",
":cublas_padding_requirements",
":cudnn_pad_for_convolutions",
":cudnn_simplify_padding",
":cudnn_vectorize_convolutions",
":cusolver_rewriter",
":gemm_algorithm_picker",
":gemm_rewriter",
17 changes: 15 additions & 2 deletions xla/service/gpu/amdgpu_compiler.cc
@@ -28,6 +28,9 @@ limitations under the License.
#include "xla/service/gpu/conv_algorithm_picker.h"
#include "xla/service/gpu/cublas_pad_for_gemms.h"
#include "xla/service/gpu/cublas_padding_requirements.h"
#include "xla/service/gpu/cudnn_pad_for_convolutions.h"
#include "xla/service/gpu/cudnn_simplify_padding.h"
#include "xla/service/gpu/cudnn_vectorize_convolutions.h"
#include "xla/service/gpu/cusolver_rewriter.h"
#include "xla/service/gpu/gemm_algorithm_picker.h"
#include "xla/service/gpu/gemm_rewriter.h"
@@ -88,6 +91,8 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
HloModule* hlo_module, se::GpuComputeCapability gpu_version,
se::dnn::VersionInfo dnn_version,
se::DeviceMemoryAllocator* device_allocator) {
auto rocm_compute_capability =
std::get<se::RocmComputeCapability>(gpu_version);
// Convert convolutions into CustomCalls to MIOpen, then canonicalize them
// (PadInsertion).
HloPassPipeline pipeline("conv_canonicalization");
@@ -96,13 +101,14 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
/*allow_mixed_precision=*/false);

// Convert unsupported bf16 convolutions to f32.
ConvBfloat16Support conv_bf16_support(
std::get<se::RocmComputeCapability>(gpu_version));
ConvBfloat16Support conv_bf16_support(rocm_compute_capability);
pipeline.AddPass<FloatNormalization>(&conv_bf16_support);

pipeline.AddPass<GpusolverRewriter>();
pipeline.AddPass<GpuConvRewriter>();
pipeline.AddPass<GpuConvPaddingLegalization>();
pipeline.AddPass<CudnnPadForConvolutions>(rocm_compute_capability);
pipeline.AddPass<CudnnVectorizeConvolutions>(rocm_compute_capability);

// The conv padding/vectorization passes which we need to get rid of. They
// also leave behind unnecessary tuple/get-tuple-element pairs that
@@ -119,6 +125,13 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
options.set_enable_unconditional_reduce_of_concat_replacement(false);
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);

// CudnnSimplifyPadding gets rid of some padding introduced by
// CudnnPadForConvolutions and used by CudnnVectorizeConvolutions. The
// pattern-matches in this pass need to be run after inlining and simplifying
// tuples from CudnnVectorizeConvolutions. We also need to run algsimp to
// e.g. clean up unnecessary nop `convert`s.
pipeline.AddPass<CudnnSimplifyPadding>();

pipeline.AddPass<HloConstantFolding>();
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());

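Note on the pattern the remaining hunks rely on: se::GpuComputeCapability is a std::variant over the CUDA and ROCm capability types, and the code below branches on which alternative it holds. A self-contained sketch of that dispatch, using stand-in structs (CudaCC, RocmCC, and UseInt8x32 are illustrative names, not XLA symbols):

// variant_dispatch_sketch.cc: illustrative only; CudaCC/RocmCC stand in for
// se::CudaComputeCapability / se::RocmComputeCapability.
#include <iostream>
#include <string>
#include <variant>

struct CudaCC {
  int major = 0;
  int minor = 0;
  bool IsAtLeast(int other_major, int other_minor = 0) const {
    return major > other_major ||
           (major == other_major && minor >= other_minor);
  }
};

struct RocmCC {
  std::string gcn_arch;
};

using GpuCC = std::variant<CudaCC, RocmCC>;

// Mirrors the "isSM75gpu || isROCm" gating used in this commit: take the
// int8x32 path on sm75+ CUDA parts and on all ROCm parts.
bool UseInt8x32(const GpuCC& cc) {
  if (std::holds_alternative<RocmCC>(cc)) return true;
  return std::get<CudaCC>(cc).IsAtLeast(7, 5);
}

int main() {
  std::cout << UseInt8x32(CudaCC{7, 5}) << "\n";      // 1
  std::cout << UseInt8x32(CudaCC{6, 1}) << "\n";      // 0
  std::cout << UseInt8x32(RocmCC{"gfx90a"}) << "\n";  // 1
  return 0;
}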
11 changes: 8 additions & 3 deletions xla/service/gpu/cudnn_pad_for_convolutions.cc
@@ -315,7 +315,7 @@ static absl::StatusOr<bool> TryResolvePaddedShapesForTensorCore(
// Adds padding to cudnn integer convolutions to make input and output feature
// maps multiples of pad_to (usually 4 or 32).
absl::StatusOr<bool> TryResolvePaddedShapesForIntegerConvolution(
int pad_to, const se::CudaComputeCapability& compute_capability,
int pad_to, const se::GpuComputeCapability& compute_capability,
HloCustomCallInstruction* conv, std::vector<Shape>* new_input_shapes_ptr,
Shape* new_result_shape_ptr) {
TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
@@ -490,13 +490,16 @@ absl::StatusOr<bool> CudnnPadForConvolutions::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
bool changed = false;
bool isCUDA = std::holds_alternative<se::CudaComputeCapability>(compute_capability_);
bool isROCm = std::holds_alternative<se::RocmComputeCapability>(compute_capability_);
for (HloComputation* comp :
module->MakeNonfusionComputations(execution_threads)) {
for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
// On Turing and later (sm75+), pad to multiples of 32 bytes if possible,
// because that lets us use the fast int8x32 data type.
bool local_changed = false;
if (compute_capability_.IsAtLeast(7, 5)) {
bool isSM75gpu = isCUDA && std::get<se::CudaComputeCapability>(compute_capability_).IsAtLeast(7, 5);
if (isSM75gpu || isROCm) {
TF_ASSIGN_OR_RETURN(
local_changed,
ResolveAndPad(conv, absl::bind_front(
@@ -512,7 +515,9 @@ absl::StatusOr<bool> CudnnPadForConvolutions::Run(
}
changed |= local_changed;
}
if (compute_capability_.IsAtLeast(se::CudaComputeCapability::VOLTA)) {
bool isVOLTA = isCUDA &&
std::get<se::CudaComputeCapability>(compute_capability_).IsAtLeast(se::CudaComputeCapability::VOLTA);
if (isVOLTA || isROCm) {
for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
TF_ASSIGN_OR_RETURN(
bool local_changed,
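For orientation, the padding this pass applies to feature dimensions is plain round-up-to-a-multiple arithmetic (pad_to is 4 for int8x4 and 32 for int8x32 kernels). A minimal sketch; RoundUpTo here is a local helper for illustration, not the XLA utility:

// pad_to_multiple_sketch.cc: the arithmetic behind padding feature maps to a
// multiple of pad_to.
#include <cstdint>
#include <iostream>

int64_t RoundUpTo(int64_t value, int64_t multiple) {
  return ((value + multiple - 1) / multiple) * multiple;
}

int main() {
  // E.g. a conv with 3 input features padded for int8x4 vs. int8x32.
  std::cout << RoundUpTo(3, 4) << "\n";    // 4
  std::cout << RoundUpTo(3, 32) << "\n";   // 32
  std::cout << RoundUpTo(64, 32) << "\n";  // 64 (already aligned)
  return 0;
}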
4 changes: 2 additions & 2 deletions xla/service/gpu/cudnn_pad_for_convolutions.h
@@ -31,7 +31,7 @@ namespace gpu {
// add slice instruction to remove unnecessary output features.
class CudnnPadForConvolutions : public HloModulePass {
public:
explicit CudnnPadForConvolutions(se::CudaComputeCapability compute_capability)
explicit CudnnPadForConvolutions(se::GpuComputeCapability compute_capability)
: compute_capability_(compute_capability) {}

absl::string_view name() const override {
@@ -44,7 +44,7 @@ class CudnnPadForConvolutions : public HloModulePass {
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

private:
const se::CudaComputeCapability compute_capability_;
const se::GpuComputeCapability compute_capability_;
};

} // namespace gpu
10 changes: 6 additions & 4 deletions xla/service/gpu/cudnn_support_utils.cc
@@ -33,7 +33,7 @@ namespace xla {
namespace gpu {

absl::StatusOr<bool> CudnnSupportsOptimizedIntegerConvolution(
const se::CudaComputeCapability& compute_capability,
const se::GpuComputeCapability& compute_capability,
HloCustomCallInstruction& conv, int vector_size) {
TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(&conv));
const Shape& input_shape = conv.operand(0)->shape();
@@ -50,9 +50,11 @@ absl::StatusOr<bool> CudnnSupportsOptimizedIntegerConvolution(

// Require cc6.1+ for any vectorized integer convolutions
// Require cc7.5+ for any IMMA convolutions
if ((vector_size == 32 && !compute_capability.IsAtLeast(7, 5)) ||
!compute_capability.IsAtLeast(6, 1)) {
VLOG(3) << "Compute capability " << compute_capability.ToString()
bool isCUDA = std::holds_alternative<se::CudaComputeCapability>(compute_capability);
// std::get would throw on ROCm; use std::get_if and only apply the CUDA
// compute-capability gates when a CUDA capability is actually held.
const auto* cuda_compute_capability =
std::get_if<se::CudaComputeCapability>(&compute_capability);
if (isCUDA &&
((vector_size == 32 && !cuda_compute_capability->IsAtLeast(7, 5)) ||
!cuda_compute_capability->IsAtLeast(6, 1))) {
VLOG(3) << "Compute capability " << cuda_compute_capability->ToString()
<< " is not sufficient for int8x" << vector_size
<< " vectorization.";
return false;
2 changes: 1 addition & 1 deletion xla/service/gpu/cudnn_support_utils.h
@@ -32,7 +32,7 @@ namespace gpu {
// This function does not guarantee that a convolution will be padded and/or
// vectorized. It only checks that it is a valid candidate for such optimization.
absl::StatusOr<bool> CudnnSupportsOptimizedIntegerConvolution(
const se::CudaComputeCapability& compute_capability,
const se::GpuComputeCapability& compute_capability,
HloCustomCallInstruction& conv, int vector_size);

// Represents configuration for the reshape-transpose-reshape operations that
9 changes: 6 additions & 3 deletions xla/service/gpu/cudnn_vectorize_convolutions.cc
@@ -335,7 +335,7 @@ absl::Status ReorderInt8NchwVect(HloCustomCallInstruction* conv,
// (The dimensions can appear in any order; which is N/C/etc is determined by
// the convolutions' dnums.)
static absl::StatusOr<bool> TryRevectorizeConv(
const se::CudaComputeCapability& compute_capability,
const se::GpuComputeCapability& compute_capability,
const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
int vect_size) {
const Shape& input_shape = conv->operand(0)->shape();
@@ -496,7 +496,7 @@ static absl::StatusOr<bool> TryRevectorizeConv(
// This requires that C be a multiple of vect_size. CudnnPadForConvolutions can
// add padding to make this true.
static absl::StatusOr<bool> TryVectorizeConv(
const se::CudaComputeCapability& compute_capability,
const se::GpuComputeCapability& compute_capability,
const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
int64_t vect_size) {
const Shape& input_shape = conv->operand(0)->shape();
@@ -625,7 +625,10 @@ absl::StatusOr<bool> CudnnVectorizeConvolutions::Run(
// Try to (re)vectorize to int8x32 if this is an sm75+ GPU. If we can't,
// fall back to int8x4.
bool local_changed = false;
if (compute_capability_.IsAtLeast(7, 5)) {
bool isSM75gpu = std::holds_alternative<se::CudaComputeCapability>(compute_capability_)
&& std::get<se::CudaComputeCapability>(compute_capability_).IsAtLeast(7, 5);
bool isROCm = std::holds_alternative<se::RocmComputeCapability>(compute_capability_);
if (isSM75gpu || isROCm) {
TF_ASSIGN_OR_RETURN(
local_changed,
TryRevectorizeConv(compute_capability_, cudnn_version_, conv, 32));
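For orientation, "vectorizing" a convolution here means splitting the channel dimension C into (C / vect_size, vect_size), the NCHW_VECT_C-style layout cuDNN's int8 kernels expect, which is why CudnnPadForConvolutions must first make C a multiple of vect_size. A toy sketch of the shape change; VectorizeChannels is an illustrative helper, not the pass's actual code:

// vectorize_shape_sketch.cc: split the channel dim and append a trailing
// vector dimension of size vect_size.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> VectorizeChannels(std::vector<int64_t> dims,
                                       int64_t channel_dim,
                                       int64_t vect_size) {
  assert(dims[channel_dim] % vect_size == 0);  // ensured by the padding pass
  dims[channel_dim] /= vect_size;
  dims.push_back(vect_size);
  return dims;
}

int main() {
  // [N=1, C=64, H=28, W=28] with vect_size=32 -> [1, 2, 28, 28, 32].
  for (int64_t d : VectorizeChannels({1, 64, 28, 28}, /*channel_dim=*/1, 32)) {
    std::cout << d << " ";
  }
  std::cout << "\n";
  return 0;
}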
7 changes: 6 additions & 1 deletion xla/service/gpu/cudnn_vectorize_convolutions.h
@@ -52,6 +52,11 @@ class CudnnVectorizeConvolutions : public HloModulePass {
: compute_capability_(compute_capability),
cudnn_version_(cudnn_version) {}

explicit CudnnVectorizeConvolutions(
se::RocmComputeCapability compute_capability)
: compute_capability_(compute_capability) {}


absl::string_view name() const override {
return "cudnn_vectorize_convolutions";
}
@@ -61,7 +66,7 @@ class CudnnVectorizeConvolutions : public HloModulePass {
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

private:
const se::CudaComputeCapability compute_capability_;
const se::GpuComputeCapability compute_capability_;
const se::dnn::VersionInfo cudnn_version_;
};

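A toy analogue of the constructor pair added above, assuming the variant member pattern shown earlier; VectorizePass, CudaCC, RocmCC, and VersionInfo are stand-ins rather than the real XLA/StreamExecutor types. It illustrates that the ROCm overload leaves the cuDNN version member value-initialized:

// two_ctor_sketch.cc: toy analogue of CudnnVectorizeConvolutions' two
// constructors over a variant compute-capability member.
#include <variant>

struct CudaCC { int major = 0; int minor = 0; };
struct RocmCC {};
struct VersionInfo { int major = 0; int minor = 0; int patch = 0; };
using GpuCC = std::variant<CudaCC, RocmCC>;

class VectorizePass {
 public:
  // CUDA path: takes the cuDNN version that gates int8x32 revectorization.
  VectorizePass(CudaCC cc, VersionInfo cudnn_version)
      : compute_capability_(cc), cudnn_version_(cudnn_version) {}

  // ROCm path: no cuDNN version; cudnn_version_ stays value-initialized.
  explicit VectorizePass(RocmCC cc) : compute_capability_(cc) {}

 private:
  const GpuCC compute_capability_;
  const VersionInfo cudnn_version_{};
};

int main() {
  VectorizePass cuda_pass(CudaCC{8, 0}, VersionInfo{8, 9, 0});
  VectorizePass rocm_pass(RocmCC{});
  (void)cuda_pass;
  (void)rocm_pass;
  return 0;
}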