From ec6823b18f7802e8b5cae4b87eca66d340a351e6 Mon Sep 17 00:00:00 2001
From: ghostplant
Date: Sun, 10 Oct 2021 05:01:13 +0000
Subject: [PATCH] wrap fp16 to ROCm-supported dtype in amdgpu (#22)

---
 README.md                      | 2 +-
 tutel/custom/custom_kernel.cpp | 2 +-
 tutel/impls/fast_dispatch.py   | 8 +++++---
 tutel/impls/moe_layer.py       | 2 +-
 4 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 9f977d5f..0d1045c9 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
 Tutel MoE: An Optimized Mixture-of-Experts Implementation.
 
 - Supported Framework: Pytorch
-- Supported GPUs: CUDA(fp32 + fp16), ROCm(fp32)
+- Supported GPUs: CUDA(fp32 + fp16), ROCm(fp32 + fp16)
 
 How to setup Tutel MoE for Pytorch:
 ```
diff --git a/tutel/custom/custom_kernel.cpp b/tutel/custom/custom_kernel.cpp
index 7dcc855c..3c406c95 100644
--- a/tutel/custom/custom_kernel.cpp
+++ b/tutel/custom/custom_kernel.cpp
@@ -20,7 +20,7 @@
 
 #define CHECK_EQ(x, y) AT_ASSERTM((x) == (y), "CHECK_EQ fails.")
 #define CHECK_NE(x, y) AT_ASSERTM((x) != (y), "CHECK_NE fails.")
-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 
 static std::string file_read(const char *path) {
diff --git a/tutel/impls/fast_dispatch.py b/tutel/impls/fast_dispatch.py
index 2cacf9fd..6254ad59 100644
--- a/tutel/impls/fast_dispatch.py
+++ b/tutel/impls/fast_dispatch.py
@@ -6,6 +6,7 @@
 import torch
 from torch import Tensor
 
+from .jit_compiler import IS_HIP_EXTENSION
 from ..jit_kernels import sparse as jit_kernel
 
 class GatingEncoder(torch.autograd.Function):
@@ -65,7 +66,8 @@ def __init__(self, num_global_experts, capacity, model_dim, dispatch_dtype):
         self.capacity = capacity
         self.model_dim = model_dim
         self.kernel_pool = dict()
-        self.dtype = dispatch_dtype
+        self.dtype = dispatch_dtype if not IS_HIP_EXTENSION else torch.float32
+        self.original_dtype = dispatch_dtype
         self.aligned_dim = model_dim // (2 if self.dtype == torch.float16 else 1)
 
     def update(self, indices_, locations_, gates_, capacity=None):
@@ -87,9 +89,9 @@ def update(self, indices_, locations_, gates_, capacity=None):
         self.func_fwd, self.func_bwd_data, self.func_bwd_gate, self.ones_helper = self.kernel_pool[tuple((sample_size, capacity))]
 
     def encode(self, data):
-        return GatingEncoder.apply(self, data)
+        return GatingEncoder.apply(self, data.to(self.dtype)).to(self.original_dtype)
 
     def decode(self, data):
-        return GatingDecoder.apply(self, data, *self.gates_)
+        return GatingDecoder.apply(self, data.to(self.dtype), *self.gates_).to(self.original_dtype)
 
 fast_dispatcher = TutelMoeFastDispatcher
diff --git a/tutel/impls/moe_layer.py b/tutel/impls/moe_layer.py
index 6451181b..96fba74a 100644
--- a/tutel/impls/moe_layer.py
+++ b/tutel/impls/moe_layer.py
@@ -275,7 +275,7 @@ def forward(self, input: Tensor, **kwargs: Any):
             if reshaped_input.size(0) > self.expected_sample_size:
                 raise Exception('MoE JIT is designed to work on sample size = %s, while receiving sample size = %s (> %s)' % (self.expected_sample_size, reshaped_input.size(0), self.expected_sample_size))
             else:
-                if get_world_rank(expert_group) == 0:
+                if get_world_rank(self.expert_group) == 0:
                     print('[WARN] MoE is initialized to keep working on sample size = %s, while receiving sample size = %s (will slow down this forward step)' % (self.expected_sample_size, reshaped_input.size(0)))
                 pad_input = torch.zeros([self.expected_sample_size, self.model_dim], dtype=reshaped_input.dtype, layout=reshaped_input.layout, device=reshaped_input.device)
                 pad_input[:reshaped_input.size(0)] = reshaped_input
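
Reviewer note, not part of the patch: below is a minimal, self-contained Python sketch of the dtype-wrapping pattern that the fast_dispatch.py change introduces. On a ROCm/HIP build the custom kernels only handle fp32, so fp16 inputs are cast to fp32 before the kernel call and cast back to the caller's dtype afterwards. The names DispatcherSketch and fp32_only_kernel are illustrative stand-ins, not Tutel APIs.

import torch

# Assumption: True on a ROCm/HIP build of the extension, False on CUDA.
IS_HIP_EXTENSION = False


def fp32_only_kernel(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a custom kernel that only supports fp32 on ROCm.
    return x * 2.0


class DispatcherSketch:
    def __init__(self, dispatch_dtype: torch.dtype):
        # Compute dtype: fall back to fp32 on HIP, remember the requested dtype
        # so results can be cast back for the caller.
        self.dtype = dispatch_dtype if not IS_HIP_EXTENSION else torch.float32
        self.original_dtype = dispatch_dtype

    def encode(self, data: torch.Tensor) -> torch.Tensor:
        # Cast in, run the kernel, cast back out.
        return fp32_only_kernel(data.to(self.dtype)).to(self.original_dtype)


if __name__ == '__main__':
    d = DispatcherSketch(torch.float16)
    y = d.encode(torch.ones(4, dtype=torch.float16))
    assert y.dtype == torch.float16  # caller always sees its own dtype back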