From bb316e4e6f8a145bf8e00ec3efbcd57add386a06 Mon Sep 17 00:00:00 2001 From: Wei CUI Date: Tue, 14 Sep 2021 02:55:37 +0000 Subject: [PATCH] initial tutel version Signed-off-by: Wei CUI --- README.md | 46 +++- examples/helloworld.py | 125 +++++++++++ setup.py | 76 +++++++ tutel/__init__.py | 3 + tutel/custom/__init__.py | 3 + tutel/custom/custom_kernel.cpp | 122 +++++++++++ tutel/impls/__init__.py | 3 + tutel/impls/fast_dispatch.py | 94 +++++++++ tutel/impls/jit_compiler.py | 51 +++++ tutel/impls/moe_layer.py | 371 +++++++++++++++++++++++++++++++++ tutel/jit_kernels/__init__.py | 3 + tutel/jit_kernels/gating.py | 79 +++++++ tutel/jit_kernels/sparse.py | 132 ++++++++++++ tutel/moe.py | 12 ++ tutel/test/eval_bgemm.py | 95 +++++++++ 15 files changed, 1207 insertions(+), 8 deletions(-) create mode 100755 examples/helloworld.py create mode 100755 setup.py create mode 100644 tutel/__init__.py create mode 100644 tutel/custom/__init__.py create mode 100644 tutel/custom/custom_kernel.cpp create mode 100644 tutel/impls/__init__.py create mode 100644 tutel/impls/fast_dispatch.py create mode 100644 tutel/impls/jit_compiler.py create mode 100644 tutel/impls/moe_layer.py create mode 100644 tutel/jit_kernels/__init__.py create mode 100644 tutel/jit_kernels/gating.py create mode 100644 tutel/jit_kernels/sparse.py create mode 100644 tutel/moe.py create mode 100755 tutel/test/eval_bgemm.py diff --git a/README.md b/README.md index 5cd7cecf..96669b2b 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,44 @@ -# Project +# Project Tutel -> This repo has been populated by an initial template to help get you started. Please -> make sure to update the content to build a great experience for community-building. +Tutel MoE: An Optimized Mixture-of-Experts Implementation. -As the maintainer of this project, please make a few updates: +- Supported Framework: Pytorch +- Supported GPUs: CUDA(fp32 + fp16), ROCm(fp32) -- Improving this README.MD file to provide a great experience -- Updating SUPPORT.MD with content about this project's support experience -- Understanding the security reporting process in SECURITY.MD -- Remove this section from the README +How to setup Tutel MoE for Pytorch: +``` +* Install Online: + + $ python3 -m pip install --user https://github.com/microsoft/tutel/releases/download/v0.1.0/tutel-0.1.0.tar.gz + +* Build from Source: + + $ git clone https://github.com/microsoft/tutel + $ python3 ./tutel/setup.py install --user +``` + +How to use Tutel-optimized MoE in Pytorch: +``` +* Tutel MoE Example: + + moe_layer = MOELayer('Top2Gate', model_dim, experts={'type': 'ffn', 'hidden_size_per_expert': 1024}) + y = moe_layer(x) + +* Usage of MOELayer Args: + + gate : the string type of MOE gate, e.g: Top1Gate, Top2Gate + model_dim : the number of channels for MOE's input tensor + experts : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network + fp32_gate : option of enabling mixed precision for gate network + scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)` + result_func : allow users to specify a lambda function to format the MoE output and aux_loss, e.g. 
`result_func = lambda output: (output, output.l_aux)` + group : specify the explicit communication group of all_to_all + seeds : a tuple containing a pair of int to specify manual seed of (shared params, local params) + +* Running MoE Hello World Model: + + $ python3 -m torch.distributed.launch --nproc_per_node=1 ./examples/helloworld.py +``` ## Contributing diff --git a/examples/helloworld.py b/examples/helloworld.py new file mode 100755 index 00000000..ae31edc8 --- /dev/null +++ b/examples/helloworld.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import time +import torch +import torch.optim as optim +import torch.nn.functional as F +import torch.distributed as dist +from torch import nn +import argparse + +from tutel import moe as tutel_moe + +parser = argparse.ArgumentParser() + +parser.add_argument('--local_rank', type=int, default=0) +parser.add_argument('--batch_size', type=int, default=4) +parser.add_argument('--num_tokens', type=int, default=512) +parser.add_argument('--model_dim', type=int, default=2048) +parser.add_argument('--hidden_size', type=int, default=1024) +parser.add_argument('--num_local_experts', type=int, default=2) +parser.add_argument('--dtype', type=str, default='float32') +parser.add_argument('--fp32_gate', default=False, action='store_true') +parser.add_argument('--top', type=int, default=2) +args = parser.parse_args() + +torch.cuda.set_device(args.local_rank) + +try: + if dist.is_available(): + dist.init_process_group('nccl') + dist_rank = dist.get_rank() + dist_world_size = dist.get_world_size() + + def dist_print(*args): + if dist_rank == 0: + print(*args) +except: + dist_rank = 0 + dist_world_size = 1 + dist_print = print + +batch_size = args.batch_size +num_tokens = args.num_tokens +model_dim = args.model_dim +hidden_size = args.hidden_size +num_local_experts = args.num_local_experts +top_value = args.top +local_rank = args.local_rank + + +device = torch.device('cuda', args.local_rank) + +if args.dtype == 'float32': + torch.set_default_dtype(torch.float32) +elif args.dtype == 'float16': + torch.set_default_dtype(torch.float16) +else: + raise Exception('Unrecognized data type specified: %s' % args.dtype) + + +class ExampleModel(torch.nn.Module): + def __init__(self): + super().__init__() + + self._moe_layer = tutel_moe.moe_layer( + gate_type = 'Top%dGate' % top_value, + model_dim = model_dim, + experts = {'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_size}, + fp32_gate = args.fp32_gate, + scan_expert_func = lambda name, param: setattr(param, 'expert', True), + seeds = (1, dist_rank + 1), + ).to(device) + + # Distinguish different parameter types: gate, local_experts + local_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='local_experts')]) + shared_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='gate')]) + dist_print('[Statistics] param count for MoE local_experts = %s, param count for MoE gate = %s.\n' % (local_count, shared_count)) + + def forward(self, input): + result = self._moe_layer(input) + result = F.log_softmax(torch.sum(result, dim=2), dim=1) + return result + +model = ExampleModel() +dist_print(model) + +optimizer = torch.optim.SGD(model.parameters(), lr=1e-5) + +x = torch.randn([batch_size, num_tokens, model_dim], device=device, requires_grad=True) +y = torch.LongTensor(batch_size).random_(1).to(device) + +tuples = 
(dist_world_size, args.dtype, model_dim, hidden_size, batch_size * num_tokens, num_local_experts, top_value, device) +dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, topK = %s, device = `%s`' % tuples) + +average_time, num_steps = 0, 100 + +params_for_all_reduce = [p for p in model.parameters() if not hasattr(p, 'expert') and getattr(p, 'requires_grad', False)] + +for i in range(num_steps): + + torch.cuda.synchronize() + t_start = time.time() + optimizer.zero_grad() + + output = model(x) + loss = F.nll_loss(output, y) + loss.backward() + if dist_world_size > 1: + for p in params_for_all_reduce: + p.grad /= dist_world_size + dist.all_reduce(p.grad) + optimizer.step() + + torch.cuda.synchronize() + t_stop = time.time() + dist_print('STEP-%s: DONE, loss = %s, step_time = %s sec.' % (i, float(loss.data), t_stop - t_start)) + + if i + 10 >= num_steps: + average_time += t_stop - t_start + +average_time /= 10 +dist_print('\n[Summary] Average synchronized step_time = %s sec.' % average_time) diff --git a/setup.py b/setup.py new file mode 100755 index 00000000..f18f29b2 --- /dev/null +++ b/setup.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""The setuptools based setup module. + +Reference: + https://packaging.python.org/guides/distributing-packages-using-setuptools/ +""" + +import os, sys + +from setuptools import setup, find_packages +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +try: + from torch.utils.cpp_extension import IS_HIP_EXTENSION +except: + IS_HIP_EXTENSION = False + +if len(sys.argv) <= 1: + sys.argv += ['install', '--user'] + +root_path = os.path.dirname(sys.argv[0]) +root_path = root_path if root_path else '.' 
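# The CUDAExtension below lists its source as './tutel/custom/custom_kernel.cpp',
# a path relative to the repository root, so switch the working directory first.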
+ +os.chdir(root_path) + +setup( + name='tutel', + version='0.1.0', + description='An Optimized Mixture-of-Experts Implementation.', + url='https://github.com/microsoft/Tutel', + author='Microsoft', + author_email='tutel@microsoft.com', + license='MIT', + classifiers=[ + 'Development Status :: 2 - Pre-Alpha', + 'Environment :: GPU', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3 :: Only', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + keywords=['Mixture of Experts', 'MoE', 'Optimization'], + packages=find_packages(), + python_requires='>=3.6, <4', + install_requires=[ + ], + ext_modules=[ + CUDAExtension('tutel_custom_kernel', [ + './tutel/custom/custom_kernel.cpp', + ], + library_dirs=['/usr/local/cuda/lib64/stubs'], + libraries=['dl', 'cuda'] if not IS_HIP_EXTENSION else []) + ], + cmdclass={ + 'build_ext': BuildExtension + }, + project_urls={ + 'Source': 'https://github.com/microsoft/Tutel', + 'Tracker': 'https://github.com/microsoft/Tutel/issues', + }, +) diff --git a/tutel/__init__.py b/tutel/__init__.py new file mode 100644 index 00000000..c45e0a75 --- /dev/null +++ b/tutel/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + diff --git a/tutel/custom/__init__.py b/tutel/custom/__init__.py new file mode 100644 index 00000000..c45e0a75 --- /dev/null +++ b/tutel/custom/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + diff --git a/tutel/custom/custom_kernel.cpp b/tutel/custom/custom_kernel.cpp new file mode 100644 index 00000000..f32b3520 --- /dev/null +++ b/tutel/custom/custom_kernel.cpp @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include +#include + +#include + +#include +#include +#include +#include + +#undef CHECK_EQ +#undef CHECK_NE +#undef CHECK_CUDA +#undef CHECK_CONTIGUOUS + +#define CHECK_EQ(x, y) AT_ASSERTM((x) == (y), "CHECK_EQ fails.") +#define CHECK_NE(x, y) AT_ASSERTM((x) != (y), "CHECK_NE fails.") +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") + + +static void invoke(const std::vector &ts, int _key) { + struct ModuleConfig { + CUmodule hMod = nullptr; + CUfunction hFunc = nullptr; + + dim3 blocks, threads; + }; + + static std::vector gpuMods; + +#if !defined(__HIP_PLATFORM_HCC__) +#if 0 + static void *libcuda = nullptr; + static int (*cuModuleLoad)(...) = nullptr; + static int (*cuModuleGetFunction)(...) = nullptr; + static int (*cuLaunchKernel)(...) = nullptr; + + if (libcuda == nullptr) { + (libcuda == nullptr ? (libcuda = dlopen("/usr/lib/x86_64-linux-gnu/libcuda.so.1", RTLD_LAZY | RTLD_GLOBAL)) : 0); + (libcuda == nullptr ? (libcuda = dlopen("/usr/lib/x86_64-linux-gnu/libcuda.so", RTLD_LAZY | RTLD_GLOBAL)) : 0); + (libcuda == nullptr ? 
(libcuda = dlopen("/usr/local/lib/x86_64-linux-gnu/libcuda.so.1", RTLD_LAZY | RTLD_GLOBAL)) : 0); + (libcuda == nullptr ? (libcuda = dlopen("/usr/local/lib/x86_64-linux-gnu/libcuda.so", RTLD_LAZY | RTLD_GLOBAL)) : 0); + (libcuda == nullptr ? (libcuda = dlopen("/usr/local/cuda/lib64/libcuda.so.1", RTLD_LAZY | RTLD_GLOBAL)) : 0); + (libcuda == nullptr ? (libcuda = dlopen("/usr/local/cuda/lib64/libcuda.so", RTLD_LAZY | RTLD_GLOBAL)) : 0); + (libcuda == nullptr ? (libcuda = dlopen("/usr/local/cuda/lib64/stubs/libcuda.so", RTLD_LAZY | RTLD_GLOBAL)) : 0); + + CHECK_NE(nullptr, libcuda); + CHECK_NE(nullptr, (cuModuleLoad = (decltype(cuModuleLoad))dlsym(libcuda, "cuModuleLoad"))); + CHECK_NE(nullptr, (cuModuleGetFunction = (decltype(cuModuleGetFunction))dlsym(libcuda, "cuModuleGetFunction"))); + CHECK_NE(nullptr, (cuLaunchKernel = (decltype(cuLaunchKernel))dlsym(libcuda, "cuLaunchKernel"))); + } +#endif +#endif + + int key_int = (_key & 255), ctx = _key >> 8; + if (ctx >= (int)gpuMods.size()) + gpuMods.resize(ctx + 1); + + auto &gm = gpuMods[ctx]; + if (gm.hFunc == nullptr) { + std::string key = std::to_string(key_int); + std::string file_name = "/tmp/" + std::to_string(ctx) + "-" + key + ".cu"; + FILE *fp = fopen(file_name.c_str(), "rb"); + CHECK_EQ(true, fp != nullptr); + fseek(fp, 0, SEEK_END); + size_t code_size = ftell(fp); + fseek(fp, 0, SEEK_SET); + std::vector code(code_size + 1); + CHECK_EQ(code_size, fread((void*)code.data(), 1, code_size, fp)); + fclose(fp); + + int dev = key_int; + CHECK_EQ(0, cudaSetDevice(dev)); + +#if !defined(__HIP_PLATFORM_HCC__) + std::string cc = "30"; + int major, minor; + CHECK_EQ(0, cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev)); + CHECK_EQ(0, cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, dev)); + std::string arch = std::to_string(major) + std::to_string(minor); + CHECK_EQ(0, system(("/usr/local/cuda/bin/nvcc " + file_name + " -o " + file_name + ".fatbin --fatbin -O2 -gencode arch=compute_" + arch + ",code=sm_" + arch).c_str())); +#else + hipDeviceProp_t prop; + CHECK_EQ(0, hipGetDeviceProperties(&prop, dev)); + std::string arch = std::to_string(prop.gcnArch); + CHECK_EQ(0, system(("/opt/rocm/bin/hipcc " + file_name + " -o " + file_name + ".fatbin --genco -O2 -w --amdgpu-target=gfx" + arch).c_str())); +#endif + CHECK_EQ(0, cuModuleLoad(&gm.hMod, (file_name + ".fatbin").c_str())); + + const char *source = code.data(), *pos, *tail; + CHECK_EQ(true, nullptr != (pos = strstr(source, " void "))); + pos += 6; CHECK_EQ(true, nullptr != (tail = strchr(pos, '('))); + + CHECK_EQ(0, cuModuleGetFunction(&gm.hFunc, gm.hMod, std::string(pos, tail - pos).c_str())); + CHECK_EQ(true, nullptr != gm.hFunc); + + { char tag[] = "// [thread_extent] blockIdx.x = "; pos = strstr(source, tag); gm.blocks.x = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; } + { char tag[] = "// [thread_extent] blockIdx.y = "; pos = strstr(source, tag); gm.blocks.y = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; } + { char tag[] = "// [thread_extent] blockIdx.z = "; pos = strstr(source, tag); gm.blocks.z = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; } + { char tag[] = "// [thread_extent] threadIdx.x = "; pos = strstr(source, tag); gm.threads.x = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; } + { char tag[] = "// [thread_extent] threadIdx.y = "; pos = strstr(source, tag); gm.threads.y = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; } + { char tag[] = "// [thread_extent] threadIdx.z = "; pos = strstr(source, tag); gm.threads.z = pos ? 
std::atoi(pos + sizeof(tag) - 1) : 1; } + } + + std::vector pargs(ts.size()), ppargs(ts.size()); + for (int i = 0; i < (int)ts.size(); ++i) { + pargs[i] = (void*)ts[i].data_ptr(), ppargs[i] = &pargs[i]; + } + + CHECK_EQ(0, cuLaunchKernel(gm.hFunc, gm.blocks.x, gm.blocks.y, gm.blocks.z, gm.threads.x, gm.threads.y, gm.threads.z, 0, nullptr, ppargs.data(), nullptr)); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("invoke", + &invoke, + "Generic Invoke (CUDA)" + ); +} diff --git a/tutel/impls/__init__.py b/tutel/impls/__init__.py new file mode 100644 index 00000000..c45e0a75 --- /dev/null +++ b/tutel/impls/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + diff --git a/tutel/impls/fast_dispatch.py b/tutel/impls/fast_dispatch.py new file mode 100644 index 00000000..6f3147a1 --- /dev/null +++ b/tutel/impls/fast_dispatch.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast + +import torch +from torch import Tensor + +from ..jit_kernels import sparse as jit_kernel + +class GatingEncoder(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, config: Any, reshaped_input: Tensor): + ctx.reshaped_input = reshaped_input + ctx.config = config + + dispatched_input = torch.zeros([ctx.config.num_global_experts * ctx.config.capacity, ctx.config.model_dim], dtype=reshaped_input.dtype, device=reshaped_input.device) + for i in range(len(ctx.config.indices_)): + ctx.config.func_fwd(ctx.config.ones_helper, ctx.config.indices_[i], ctx.config.locations_[i], reshaped_input, dispatched_input) + return dispatched_input + + @staticmethod + def backward(ctx: Any, dispatched_input: Tensor): + last_result = None + for i in range(len(ctx.config.indices_)): + grad_data = torch.empty(ctx.reshaped_input.shape, dtype=dispatched_input.dtype, device=dispatched_input.device) + ctx.config.func_bwd_data(ctx.config.ones_helper, dispatched_input, ctx.config.indices_[i], ctx.config.locations_[i], grad_data) + last_result = grad_data if last_result is None else last_result + grad_data + return (None, last_result) + + +class GatingDecoder(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, config: Any, expert_output: Tensor, *gates_: Tensor): + ctx.expert_output = expert_output + ctx.gates_h2 = [x.view(-1, 1).repeat(1, 2) if x.dtype == torch.float16 else x for x in gates_] + ctx.config = config + + last_result = None + for i in range(len(config.indices_)): + single_output = torch.empty([config.expected_sample_size, config.model_dim], dtype=expert_output.dtype, device=expert_output.device) + config.func_bwd_data(ctx.gates_h2[i], expert_output, config.indices_[i], config.locations_[i], single_output) + last_result = single_output if last_result is None else last_result + single_output + return last_result + + @staticmethod + def backward(ctx: Any, combined_output: Tensor): + grad_expert_output = torch.zeros(ctx.expert_output.shape, dtype=combined_output.dtype, device=combined_output.device) + for i in range(len(ctx.config.indices_)): + ctx.config.func_fwd(ctx.gates_h2[i], ctx.config.indices_[i], ctx.config.locations_[i], combined_output, grad_expert_output) + + grad_gates = [] + for i in range(len(ctx.config.indices_)): + grad_gates1_s = torch.empty([ctx.config.expected_sample_size,], dtype=combined_output.dtype, device=combined_output.device) + ctx.config.func_bwd_gate(ctx.expert_output, ctx.config.indices_[i], ctx.config.locations_[i], 
combined_output, grad_gates1_s) + grad_gates.append(grad_gates1_s) + return (None, grad_expert_output, *grad_gates) + + +class TutelMoeFastDispatcher: + + def __init__(self, num_global_experts, capacity, model_dim, dispatch_dtype): + self.expected_sample_size = -1 + self.num_global_experts = num_global_experts + self.capacity = capacity + self.model_dim = model_dim + self.kernel_pool = dict() + self.dtype = dispatch_dtype + self.aligned_dim = model_dim // (2 if self.dtype == torch.float16 else 1) + + def update(self, indices_, locations_, gates_): + self.indices_ = [x.to(torch.int32).view(-1) for x in indices_] + self.locations_ = [x.to(torch.int32) for x in locations_] + self.gates_ = [x.to(self.dtype) for x in gates_] + sample_size = self.indices_[0].size(0) + + if sample_size != self.expected_sample_size: + self.expected_sample_size = sample_size + if sample_size not in self.kernel_pool: + self.func_fwd = jit_kernel.create_forward(sample_size, self.num_global_experts, self.capacity, self.aligned_dim, self.dtype) + self.func_bwd_data = jit_kernel.create_backward_data(sample_size, self.num_global_experts, self.capacity, self.aligned_dim, self.dtype) + self.func_bwd_gate = jit_kernel.create_backward_gate(sample_size, self.num_global_experts, self.capacity, self.aligned_dim, self.dtype) + self.ones_helper = torch.ones([sample_size, 2], dtype=self.dtype, device=self.indices_[0].device) + self.kernel_pool[sample_size] = self.func_fwd, self.func_bwd_data, self.func_bwd_gate, self.ones_helper + else: + self.func_fwd, self.func_bwd_data, self.func_bwd_gate, self.ones_helper = self.kernel_pool[sample_size] + + def encode(self, data): + return GatingEncoder.apply(self, data) + + def decode(self, data): + return GatingDecoder.apply(self, data, *self.gates_) + +fast_dispatcher = TutelMoeFastDispatcher diff --git a/tutel/impls/jit_compiler.py b/tutel/impls/jit_compiler.py new file mode 100644 index 00000000..8ea8d724 --- /dev/null +++ b/tutel/impls/jit_compiler.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import torch +import os, tempfile + +assert torch.cuda.is_available() == True, "This version of Tutel MoE only supports CUDA. More backends will be supported soon." + +try: + import tutel_custom_kernel +except: + raise Exception("Cannot import JIT optimized kernels. 
Did you forget to install Custom Kernel Extension?") + +try: + from torch.utils.cpp_extension import IS_HIP_EXTENSION +except: + IS_HIP_EXTENSION = False + +try: + local_rank = int(os.environ.get('LOCAL_RANK', '0')) +except: + local_rank = 0 + +class JitCompiler: + @staticmethod + def create_raw(source): + if not hasattr(JitCompiler, '__CTX__'): + torch.cuda.init() + JitCompiler.__CTX__ = 0 + __ctx__ = JitCompiler.__CTX__ + JitCompiler.__CTX__ += 1 + + key = local_rank + temp_loc = '%s-%s.MoE' % (tempfile.mktemp(), __ctx__) + with open(temp_loc, 'w') as fp: + if IS_HIP_EXTENSION: + fp.write('#include \n#include \n') + else: + fp.write('#include \n#include \n') + fp.write(source) + os.rename(temp_loc, '/tmp/%s-%s.cu' % (__ctx__, key)) + + def func(*inputs): + tutel_custom_kernel.invoke(inputs, __ctx__ * 256 + key) + return func + + @staticmethod + def generate_kernel(keyword_dict, template): + for key in keyword_dict: + template = template.replace('@%s@' % key, str(keyword_dict[key])) + return JitCompiler.create_raw(template) diff --git a/tutel/impls/moe_layer.py b/tutel/impls/moe_layer.py new file mode 100644 index 00000000..21e0fb47 --- /dev/null +++ b/tutel/impls/moe_layer.py @@ -0,0 +1,371 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast + +import torch +from torch import Tensor +import torch.distributed as dist +from torch.nn import ModuleList +import torch.nn.functional as F + +from ..impls.fast_dispatch import fast_dispatcher +from ..jit_kernels.gating import fast_cumsum_sub_one + + +def get_world_size(group): + try: + return dist.get_world_size(group) + except: + return 1 + +def get_world_rank(group): + try: + return dist.get_rank(group) + except: + return 0 + +def one_hot_with_dtype(data, num_classes, dtype): + result = torch.zeros([data.size(0), num_classes], device=data.device, dtype=dtype) + result.scatter_(1, data.unsqueeze(-1), 1) + return result + + +class AllToAll(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor): + ctx.group = group + ctx.world_size = get_world_size(group) + if ctx.world_size <= 1 or AllToAll.skip_a2a: + return input + input = input.contiguous() + output = torch.empty_like(input) + dist.all_to_all_single(output, input, group=group) + return output + + @staticmethod + def backward(ctx: Any, grad_output: Tensor): + if ctx.world_size <= 1 or AllToAll.skip_a2a: + return (None, grad_output) + return (None, AllToAll.apply(ctx.group, grad_output)) + + +def load_balance(gates, mask1, num_global_experts, use_fp32): + if gates.dtype == torch.float32 or use_fp32: + me = torch.sum(gates.float(), dim=0) + ce = torch.sum(mask1.to(me.dtype), dim=0) + l_loss = torch.sum(me * ce) * (num_global_experts / (gates.size(0) * gates.size(0))) + else: + me = torch.mean(gates, dim=0) + ce = torch.mean(mask1.to(gates.dtype), dim=0) + l_loss = torch.sum(me * ce) * num_global_experts + return l_loss + + +class Top1Gate(torch.nn.Module): + + def __init__( + self, + model_dim, + num_global_experts, + capacity_factor=1.0, + use_fp32=False, + ): + super().__init__() + self.wg = torch.nn.Linear(model_dim, num_global_experts, bias=False) + self.capacity_factor = capacity_factor + self.use_fp32 = use_fp32 + self.num_global_experts = num_global_experts + + def capacity(self, expected_sample_size): + if not hasattr(self, 'capacity_int'): + self.capacity_int = int(self.capacity_factor * ((expected_sample_size + self.num_global_experts - 1) 
// self.num_global_experts)) + return self.capacity_int + + def forward(self, input: torch.Tensor): + logits = self.wg(input) + + indices1_s = torch.argmax(logits, dim=1) + mask1 = one_hot_with_dtype(indices1_s, num_classes=self.num_global_experts, dtype=indices1_s.dtype) + + mask1_ = mask1.to(logits.dtype) + gates = F.softmax(logits, dim=1) + gates1_s = (gates * mask1_).sum(dim=1) + l_loss = load_balance(gates, mask1_, self.num_global_experts, self.use_fp32) + + locations1 = fast_cumsum_sub_one(mask1) + locations1_s = torch.sum(locations1 * mask1, dim=1).to(torch.int32) + + return l_loss, [gates1_s, ], [indices1_s.to(torch.int32), ], [locations1_s.to(torch.int32), ] + + +class Top2Gate(torch.nn.Module): + + def __init__( + self, + model_dim, + num_global_experts, + capacity_factor=1.0, + use_fp32=False, + ): + super().__init__() + self.wg = torch.nn.Linear(model_dim, num_global_experts, bias=False) + self.capacity_factor = capacity_factor + self.use_fp32 = use_fp32 + self.num_global_experts = num_global_experts + assert self.num_global_experts >= 2, "You have only 1 expert, while you are using a top-2 gate." + + def capacity(self, expected_sample_size): + if not hasattr(self, 'capacity_int'): + self.capacity_int = 2 * int(self.capacity_factor * ((expected_sample_size + self.num_global_experts - 1) // self.num_global_experts)) + return self.capacity_int + + def forward(self, input: torch.Tensor): + logits = self.wg(input) + + top2_indices = torch.topk(logits, 2, dim=1).indices + indices1_s, indices2_s = top2_indices.chunk(2, dim=1) + indices1_s, indices2_s = indices1_s.view(-1), indices2_s.view(-1) + + mask1 = one_hot_with_dtype(indices1_s, num_classes=self.num_global_experts, dtype=indices1_s.dtype) + mask2 = one_hot_with_dtype(indices2_s, num_classes=self.num_global_experts, dtype=indices2_s.dtype) + + gates = F.softmax(logits, dim=1) + gates1_s = (gates * mask1).sum(dim=1) + gates2_s = (gates * mask2).sum(dim=1) + l_loss = load_balance(gates, mask1, self.num_global_experts, self.use_fp32) + + locations1 = fast_cumsum_sub_one(mask1) + locations1_s = torch.sum(locations1 * mask1, dim=1).to(torch.int32) + + locations2 = fast_cumsum_sub_one(mask2) + locations2 += torch.sum(mask1, dim=0, keepdim=True) + locations2_s = torch.sum(locations2 * mask2, dim=1) + + # Normalize Gate + denom_s = torch.clamp(gates1_s + gates2_s, min=torch.finfo(gates2_s.dtype).eps) + gates1_s /= denom_s + gates2_s /= denom_s + + return l_loss, [gates1_s, gates2_s], [indices1_s.to(torch.int32), indices2_s.to(torch.int32)], [locations1_s.to(torch.int32), locations2_s.to(torch.int32)] + + +class MOELayer(torch.nn.Module): + """Tutel optimized MOELayer + + e.g. + + moe_layer = MOELayer('Top2Gate', model_dim, experts={'type': 'ffn', 'hidden_size_per_expert': 1024}) + y = moe_layer(x) + + Args: + gate : the string type of MOE gate, e.g: Top1Gate, Top2Gate + model_dim : the number of channels for MOE's input tensor + experts : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network + fp32_gate : option of enabling mixed precision for gate network + scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)` + result_func : allow users to specify a lambda function to format the MoE output and aux_loss, e.g. 
`result_func = lambda output: (output, output.l_aux)` + group : specify the explicit communication group of all_to_all + seeds : a tuple containing a pair of int to specify manual seed of (shared params, local params) + """ + + def __init__(self, gate_type, model_dim: int, experts = None, fp32_gate = False, scan_expert_func = None, result_func = None, group: Optional[Any] = None, seeds = None): + super().__init__() + + assert model_dim % 2 == 0, "Model_dim (%s) must be even value, while this Model_dim mod 2 > 0." % model_dim + self.expert_group = group = group if group is not None else dist.group.WORLD + self.world_size = get_world_size(self.expert_group) + self.result_func = result_func + + import os + self.skip_moe = (int(os.environ.get('SKIP_MOE', '0')) != 0) + AllToAll.skip_a2a = (int(os.environ.get('SKIP_A2A', '0')) != 0) + + if not isinstance(experts, dict): + self.experts = cast(ModuleList, experts) if type(experts) == ModuleList else ModuleList(experts) + self.num_local_experts = len(self.experts) + else: + network_type = experts['type'] + if network_type == 'ffn': + ''' << Fused FFN Experts V1 >> (kernels = 5) + + hidden[W, E, C, V] +=! input[W, E, C, M] x expert_fc1[0, E, M, V] + hidden[W, E, C, V] = hidden[W, E, C, V] + bias_fc1[E, V] + hidden[W, E, C, V] = activation_fn(hidden[W, E, C, V]) + hidden[W, E, C, M] +=! hidden[W, E, C, V] x expert_fc2[0, E, V, M] + output[W, E, C, M] = hidden[W, E, C, M] + bias_fc2[E, M] + + << Fused FFN Experts V2 >> (kernels = 7) + + hidden[E, W, C, M] = input[W, E, C, M] + hidden[E, W, C, V] +=! hidden[E, W, C, M] x expert_fc1[0, E, M, V] + hidden[E, W, C, V] = hidden[E, W, C, V] + bias_fc1[E, V] + hidden[E, W, C, V] = activation_fn(hidden[E, W, C, V]) + hidden[E, W, C, M] +=! hidden[E, W, C, V] x expert_fc2[0, E, V, M] + hidden[E, W, C, M] = hidden[E, W, C, M] + bias_fc2[E, M] + output[E, W, C, M] = hidden[E, W, C, M] + ''' + + self.num_local_experts = experts.get('count_per_node', 1) + fused_custom_fn = experts.get('fused_custom_fn') + if fused_custom_fn is None: + activation_fn = experts.get('activation_fn', lambda x: F.relu(x)) + + class FusedExpertsNetwork(torch.nn.Module): + def __init__(self, model_dim, hidden_size, local_experts): + super().__init__() + self.skip_expert = (int(os.environ.get('SKIP_EXPERT', '0')) != 0) + + fc1_weight = torch.empty(1, local_experts, model_dim, hidden_size) + fc2_weight = torch.empty(1, local_experts, hidden_size, model_dim) + fc1_bias = torch.empty(1, local_experts, 1, hidden_size) + fc2_bias = torch.empty(1, local_experts, 1, model_dim) + + for i in range(local_experts): + fc1 = torch.nn.Linear(model_dim, hidden_size) + fc2 = torch.nn.Linear(hidden_size, model_dim) + fc1_weight[0, i, :, :], fc1_bias[0, i, :, :] = fc1.weight.t(), fc1.bias + fc2_weight[0, i, :, :], fc2_bias[0, i, :, :] = fc2.weight.t(), fc2.bias + + self.model_dim, self.hidden_size, self.local_experts = model_dim, hidden_size, local_experts + + if self.local_experts == 1: + fc1_weight = fc1_weight.view(self.model_dim, self.hidden_size) + fc2_weight = fc2_weight.view(self.hidden_size, self.model_dim) + fc1_bias = fc1_bias.view(-1, self.hidden_size) + fc2_bias = fc2_bias.view(-1, self.model_dim) + else: + fc1_weight = fc1_weight.view(self.local_experts, self.model_dim, self.hidden_size) + fc2_weight = fc2_weight.view(self.local_experts, self.hidden_size, self.model_dim) + fc1_bias = fc1_bias.view(self.local_experts, 1, self.hidden_size) + fc2_bias = fc2_bias.view(self.local_experts, 1, self.model_dim) + + 
self.register_parameter(name='fc1_weight', param=torch.nn.Parameter(fc1_weight)) + self.register_parameter(name='fc2_weight', param=torch.nn.Parameter(fc2_weight)) + self.register_parameter(name='fc1_bias', param=torch.nn.Parameter(fc1_bias)) + self.register_parameter(name='fc2_bias', param=torch.nn.Parameter(fc2_bias)) + + def extra_repr(self): + return 'model_dim=%d, hidden_size=%d, local_experts=%d' % (self.model_dim, self.hidden_size, self.local_experts) + + def forward(self, x): + if self.skip_expert: + return x + if fused_custom_fn is not None: + x = fused_custom_fn(self, x) + elif self.local_experts == 1: + original_shape, x = x.shape, x.view(-1, self.model_dim) + x = torch.addmm(self.fc1_bias, x, self.fc1_weight) + x = activation_fn(x) + x = torch.addmm(self.fc2_bias, x, self.fc2_weight) + x = x.view(original_shape) + else: + x = x.permute(1, 0, 2, 3) + original_shape, x = x.shape, x.reshape(self.local_experts, -1, self.model_dim) + x = torch.matmul(x, self.fc1_weight) + self.fc1_bias + x = activation_fn(x) + x = torch.matmul(x, self.fc2_weight) + self.fc2_bias + x = x.reshape(self.local_experts, original_shape[1], original_shape[2], self.model_dim) + x = x.permute(1, 0, 2, 3) + return x + + def to(self, *args, **kwargs): + self = super().to(*args, **kwargs) + self.fc1_weight = self.fc1_weight.to(*args, **kwargs) + self.fc2_weight = self.fc2_weight.to(*args, **kwargs) + self.fc1_bias = self.fc1_bias.to(*args, **kwargs) + self.fc2_bias = self.fc2_bias.to(*args, **kwargs) + return self + + if seeds is not None: + torch.manual_seed(seeds[1]) + self.experts = ModuleList([FusedExpertsNetwork(model_dim, experts['hidden_size_per_expert'], self.num_local_experts)]) + else: + raise Exception('Builtin expert type is not recognized: %s' % network_type) + + if scan_expert_func is not None: + for expert in self.experts: + for n, p in expert.named_parameters(): + scan_expert_func(n, p) + + self.num_global_experts = self.world_size * self.num_local_experts + self.model_dim = model_dim + + if gate_type == 'Top1Gate' or (gate_type == 'Top2Gate' and self.num_global_experts == 1): + gating = Top1Gate + elif gate_type == 'Top2Gate': + gating = Top2Gate + else: + raise Exception("Unrecognized gate_type: %s" % gate_type) + + if seeds is not None: + torch.manual_seed(seeds[0]) + self.gate = gating(model_dim=model_dim, num_global_experts=self.num_global_experts, use_fp32=fp32_gate) + + def get_parameter_iterator(self, param_type): + if param_type == 'gate': + return self.gate.named_parameters() + elif param_type == 'local_experts': + return self.experts.named_parameters() + else: + raise Exception("Specified parameter type is not recognized: %s. Valid `param_type` includes: gate, local_experts." 
% param_type) + + def forward(self, input: Tensor, **kwargs: Any): + if self.skip_moe: + result_output = input + result_output.l_aux = None + return self.result_func(result_output) if self.result_func is not None else result_output + + original_shape, original_dtype = input.shape, input.dtype + assert len(input.shape) >= 2, "Input data must be at least 2D tensor: (s)amples, .., (m)odel_dim" + reshaped_input = input.reshape(-1, input.shape[-1]) + reshaped_input_samples = reshaped_input.shape[0] + + self.expected_sample_size = getattr(self, 'expected_sample_size', 0) or reshaped_input.size(0) + if reshaped_input.size(0) != self.expected_sample_size: + if reshaped_input.size(0) > self.expected_sample_size: + raise Exception('MoE JIT is designed to work on sample size = %s, while receiving sample size = %s (> %s)' % (self.expected_sample_size, reshaped_input.size(0), self.expected_sample_size)) + else: + print('MoE is initialized to keep working on sample size = %s, while receiving sample size = %s (will slow down this forward step)' % (self.expected_sample_size, reshaped_input.size(0))) + pad_input = torch.zeros([self.expected_sample_size, self.model_dim], dtype=reshaped_input.dtype, layout=reshaped_input.layout, device=reshaped_input.device) + pad_input[:reshaped_input.size(0)] = reshaped_input + reshaped_input = pad_input + + if not hasattr(self, 'param_dtype'): + self.param_dtype = next(iter(self.experts.parameters())).dtype + self.capacity = self.gate.capacity(self.expected_sample_size) + + reshaped_input = reshaped_input.to(self.param_dtype) + l_aux, gates_, indices_, locations_ = self.gate(reshaped_input) + + if not hasattr(self, '_tutel_dispatcher'): + self._tutel_dispatcher = fast_dispatcher(num_global_experts=self.num_global_experts, capacity=self.capacity, model_dim=self.model_dim, dispatch_dtype=reshaped_input.dtype) + + self._tutel_dispatcher.update(indices_, locations_, gates_) + + S, M, GE, C = self.expected_sample_size, self.model_dim, self.num_global_experts, self.capacity + + dispatched_input = self._tutel_dispatcher.encode(reshaped_input) + dispatched_input = AllToAll.apply(self.expert_group, dispatched_input) + + dispatched_input = dispatched_input.reshape(self.world_size, self.num_local_experts, -1, M) + + if len(self.experts) == 1: + expert_output = self.experts[0](dispatched_input) + else: + chunks = dispatched_input.chunk(self.num_local_experts, dim=1) + expert_outputs = [expert(chunk) for chunk, expert in zip(chunks, self.experts)] + expert_output = torch.cat(expert_outputs, dim=1) + + expert_output = AllToAll.apply(self.expert_group, expert_output) + + expert_output = expert_output.reshape(self.world_size * self.num_local_experts, -1, M) + + result_output = self._tutel_dispatcher.decode(expert_output.view(GE * C, M)) + + result_output = result_output[:reshaped_input_samples, :] + result_output = result_output.view(original_shape).to(original_dtype) + self.l_aux = result_output.l_aux = l_aux + return self.result_func(result_output) if self.result_func is not None else result_output + +moe_layer = MOELayer diff --git a/tutel/jit_kernels/__init__.py b/tutel/jit_kernels/__init__.py new file mode 100644 index 00000000..c45e0a75 --- /dev/null +++ b/tutel/jit_kernels/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
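# Subpackage for the JIT-generated CUDA kernels: gating.py provides the fused cumsum
# used to compute per-token locations, and sparse.py provides the dispatch/combine
# kernels consumed by tutel/impls/fast_dispatch.py.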
+ diff --git a/tutel/jit_kernels/gating.py b/tutel/jit_kernels/gating.py new file mode 100644 index 00000000..893147ac --- /dev/null +++ b/tutel/jit_kernels/gating.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import torch +from ..impls.jit_compiler import JitCompiler + + +disable_gate_opt = int(os.environ.get('GATE', '1')) == 0 +cumsum_kernels = dict() + +def get_cumsum_kernel(samples, global_experts): + + if disable_gate_opt: + print('[WARN]', "Optimized cumsum is disabled, and may result in big performance regression.") + + def torch_cumsum(mask1): + locations1 = torch.cumsum(mask1, dim=0) - 1 + return locations1 + return torch_cumsum + + global cumsum_kernels + if samples in cumsum_kernels: + return cumsum_kernels[samples] + + base_kernel = JitCompiler.generate_kernel({'batch_num': global_experts, 'num_samples': samples}, ''' + #define thread_num 1024 + #define batch_num (@batch_num@) + + extern "C" __global__ void cumsum(int* input0 /* (num_samples, batch_num) */, int* output0 /* (num_samples, batch_num) */) { + // [thread_extent] blockIdx.x = @batch_num@ + // [thread_extent] threadIdx.x = 1024 + __shared__ int temp[thread_num + 1]; + int thid = threadIdx.x, bid = blockIdx.x; + int last_sum = -1; + + for (int S = 0; S < @num_samples@; S += thread_num, output0 += thread_num * batch_num, input0 += thread_num * batch_num) { + int offset = 1; + if (S + thid < @num_samples@) + temp[thid] = input0[thid * batch_num + bid]; + for (int d = thread_num >> 1; d > 0; d >>= 1) { + __syncthreads(); + if (thid < d) + temp[offset * (2 * thid + 2) - 1] += temp[offset * (2 * thid + 1) - 1]; + offset *= 2; + } + if (thid == 0) + temp[thread_num] = temp[thread_num - 1], temp[thread_num - 1] = 0; + for (int d = 1; d < thread_num; d *= 2) { + offset >>= 1; + __syncthreads(); + if (thid < d) { + int ai = offset * (2 * thid + 1) - 1; + int bi = offset * (2 * thid + 2) - 1; + int t = temp[ai]; + temp[ai] = temp[bi]; + temp[bi] += t; + } + } + __syncthreads(); + if (S + thid < @num_samples@) + output0[thid * batch_num + bid] = temp[thid + 1] + last_sum; + last_sum += temp[thread_num]; + } + } + ''') + + def optimized_cumsum(mask1): + locations1 = torch.empty(mask1.shape, dtype=torch.int32, device=mask1.device).contiguous() + base_kernel(mask1.to(torch.int32), locations1) + return locations1 + + cumsum_kernels[samples] = optimized_cumsum + return optimized_cumsum + +def fast_cumsum_sub_one(data, dim=0): + if data.dim() != 2 or dim != 0: + raise Exception("Unimplemented fast_cumsum_sub_one() of data = %s and dim = %s" % (data.size(), dim)) + return get_cumsum_kernel(data.size(0), data.size(1))(data) diff --git a/tutel/jit_kernels/sparse.py b/tutel/jit_kernels/sparse.py new file mode 100644 index 00000000..ba3ddbc1 --- /dev/null +++ b/tutel/jit_kernels/sparse.py @@ -0,0 +1,132 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
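# These helpers JIT-compile the per-shape CUDA kernels used by TutelMoeFastDispatcher:
#  - create_forward:       scatter each token into its (expert, capacity) slot, scaled by its gate value
#  - create_backward_data: gather from the dispatched layout back to token order, scaled by the given
#                          gates (called with a ones vector for encode's backward, and with the real
#                          gates as the combine step in GatingDecoder); dropped tokens receive zeros
#  - create_backward_gate: per-token dot product between the dispatched tensor and the token-ordered
#                          tensor, reduced across one warp, giving the gradient w.r.t. the gate values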
+ +import torch +from ..impls.jit_compiler import JitCompiler + + +def get_kernel_dtype(param_dtype): + if param_dtype == torch.float16: + return '__half2' + elif param_dtype == torch.float32: + return 'float' + else: + raise Exception("Unrecognized data type: %s" % param_dtype) + + +def create_forward(samples, global_experts, capacity, aligned_dim, param_dtype): + return JitCompiler.generate_kernel({'capacity': capacity, 'samples': samples, 'hidden': aligned_dim, 'dtype': get_kernel_dtype(param_dtype), 'IS_FLOAT': 1 if param_dtype == torch.float32 else 0}, ''' + #define capacity (@capacity@) + #define samples (@samples@) + #define hidden (@hidden@) + #define __dtype @dtype@ + + extern "C" __global__ __launch_bounds__(1024) void execute(__dtype* __restrict__ gates1_s, int* __restrict__ indices1_s, int* __restrict__ locations1_s, __dtype* __restrict__ reshaped_input, __dtype* __restrict__ dispatched_input) { + // [thread_extent] blockIdx.x = 128 + // [thread_extent] threadIdx.x = 1024 + + for (int i = blockIdx.x; i < samples; i += gridDim.x) + if (locations1_s[i] < capacity) { + #pragma unroll + for (int j = threadIdx.x; j < hidden; j += 1024) + atomicAdd(&dispatched_input[(indices1_s[i] * capacity + locations1_s[i]) * (hidden) + j], gates1_s[i] * reshaped_input[i * (hidden) + j]); + } + } + ''') + + +def create_backward_data(samples, global_experts, capacity, aligned_dim, param_dtype): + return JitCompiler.generate_kernel({'capacity': capacity, 'samples': samples, 'hidden': aligned_dim, 'dtype': get_kernel_dtype(param_dtype), 'IS_FLOAT': 1 if param_dtype == torch.float32 else 0}, ''' + #define capacity (@capacity@) + #define samples (@samples@) + #define hidden (@hidden@) + #define __dtype @dtype@ + + extern "C" __global__ __launch_bounds__(1024) void execute(__dtype* __restrict__ gates1_s, __dtype* __restrict__ dispatched_input, int* __restrict__ indices1_s, int* __restrict__ locations1_s, __dtype* __restrict__ grad_reshaped_input) { + // [thread_extent] blockIdx.x = 128 + // [thread_extent] threadIdx.x = 1024 + + for (int i = blockIdx.x; i < samples; i += gridDim.x) + if (locations1_s[i] < capacity) { + #pragma unroll + for (int j = threadIdx.x; j < hidden; j += 1024) + grad_reshaped_input[i * hidden + j] = gates1_s[i] * dispatched_input[(indices1_s[i] * capacity + locations1_s[i]) * (hidden) + j]; + } else { + #pragma unroll + for (int j = threadIdx.x; j < hidden; j += 1024) + #if @IS_FLOAT@ + grad_reshaped_input[i * hidden + j] = __dtype(0); + #else + grad_reshaped_input[i * hidden + j] = __dtype(0, 0); + #endif + } + } + ''') + + +def create_backward_gate(samples, global_experts, capacity, aligned_dim, param_dtype): + return JitCompiler.generate_kernel({'capacity': capacity, 'samples': samples, 'hidden': aligned_dim, 'dtype': get_kernel_dtype(param_dtype), 'IS_FLOAT': 1 if param_dtype == torch.float32 else 0}, ''' + #define capacity (@capacity@) + #define samples (@samples@) + #define hidden (@hidden@) + #define __dtype @dtype@ + + extern "C" __global__ __launch_bounds__(32) void execute(__dtype* __restrict__ dispatched_input, int* __restrict__ indices1_s, int* __restrict__ locations1_s, __dtype* __restrict__ reshaped_input, void* __restrict__ grad_gates1_s) { + // [thread_extent] blockIdx.x = @samples@ + // [thread_extent] threadIdx.x = 32 + if (locations1_s[blockIdx.x] >= capacity) { + if (((int)threadIdx.x) == 0) + #if @IS_FLOAT@ + ((float*)grad_gates1_s)[(((int)blockIdx.x))] = 0; + #else + ((half*)grad_gates1_s)[(((int)blockIdx.x))] = __float2half_rn(0.000000e+00f); + #endif + 
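          // this token overflowed the expert capacity: its gate receives a zero gradient
          // and the warp reduction below is skipped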
return; + } + int indice = indices1_s[(int)blockIdx.x] * capacity + locations1_s[(int)blockIdx.x]; + #if @IS_FLOAT@ + __dtype grad_gates1_s_rf = 0.000000e+00f; + #else + __dtype grad_gates1_s_rf = __dtype(0, 0); + #endif + for (int i = threadIdx.x; i < hidden; i += 32) + grad_gates1_s_rf += dispatched_input[indice * (hidden) + i] * reshaped_input[((int)blockIdx.x) * (hidden) + i]; + +#if !defined(__HIPCC__) + __dtype red_buf0[1]; + uint mask[1]; + __dtype t0[1]; + red_buf0[(0)] = grad_gates1_s_rf; + mask[(0)] = __activemask(); + t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 16, 32); + red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); + t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 8, 32); + red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); + t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 4, 32); + red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); + t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 2, 32); + red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); + t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 1, 32); + red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); + red_buf0[(0)] = __shfl_sync(mask[(0)], red_buf0[(0)], 0, 32); +#else + __shared__ __dtype red_buf0[32]; + __syncthreads(); + ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = grad_gates1_s_rf; + if (((int)threadIdx.x) < 16) { + ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 16))])); + ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 8))])); + ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 4))])); + ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 2))])); + ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 1))])); + } + __syncthreads(); +#endif + if (((int)threadIdx.x) == 0) + #if @IS_FLOAT@ + ((float*)grad_gates1_s)[(((int)blockIdx.x))] = red_buf0[(0)]; + #else + ((half*)grad_gates1_s)[(((int)blockIdx.x))] = red_buf0[(0)].x + red_buf0[(0)].y; + #endif + } + ''') diff --git a/tutel/moe.py b/tutel/moe.py new file mode 100644 index 00000000..ffd0c0d4 --- /dev/null +++ b/tutel/moe.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +High-level interface available for users: + +""" + +from .jit_kernels.gating import fast_cumsum_sub_one +from .impls.fast_dispatch import fast_dispatcher +from .impls.moe_layer import moe_layer + diff --git a/tutel/test/eval_bgemm.py b/tutel/test/eval_bgemm.py new file mode 100755 index 00000000..b725288d --- /dev/null +++ b/tutel/test/eval_bgemm.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
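# Standalone micro-benchmark comparing batched-GEMM strategies for the fused expert FFN:
# broadcast matmul vs. explicit repeat vs. per-world / per-expert loops vs. a flattened
# 2-D GEMM, reporting the average synchronized step time and TFLOPS of each variant.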
+ +import time +import torch +import torch.distributed as dist +import argparse + +parser = argparse.ArgumentParser() + +parser.add_argument('--fp16', default=False, action='store_true') +parser.add_argument('--w', type=int, default=1) +parser.add_argument('--e', type=int, default=2) +parser.add_argument('--m', type=int, default=2048) +parser.add_argument('--h', type=int, default=2048) + +args = parser.parse_args() + +assert args.h % args.e == 0 +args.h //= args.e + +device = torch.device('cuda', 0) +default_dtype = torch.float16 if args.fp16 else torch.float32 + +default_dtype = torch.float16 if args.fp16 else torch.float32 +X = torch.randn([args.w, args.e, args.m, args.m], dtype=default_dtype, device=device) +Y = torch.randn([1, args.e, args.m, args.h], dtype=default_dtype, device=device) +Z = torch.randn([args.w, args.e, args.m, args.h], dtype=default_dtype, device=device) + +print('X = %s, Y = %s => Z = %s' % (X.size(), Y.size(), Z.size())) + +def evaluate(func_name): + func = eval(func_name) + average_time, num_steps = 0, 30 + for i in range(num_steps): + torch.cuda.synchronize() + t_start = time.time() + func() + torch.cuda.synchronize() + t_stop = time.time() + if i + 10 >= num_steps: + average_time += t_stop - t_start + average_time /= 10 + tflops = (2.0 * args.w * args.e * args.h * args.m * args.m) / average_time * 1e-12 + print('\n[Summary] Average synchronized step_time of `%s:%s` = %s sec. (Tflops = %s)' % ( + func_name, default_dtype, average_time, tflops)) + return average_time + +X_l = torch.randn([args.w * args.e * args.m, args.m], dtype=default_dtype, device=device) +Y_l = torch.randn([args.m, args.h], dtype=default_dtype, device=device) +def layout_sgemm(): + torch.matmul(X_l, Y_l) + +def auto_broadcast_bgemm(): + torch.matmul(X, Y) + +def manual_broadcast_bgemm(): + torch.matmul(X, Y.repeat(X.size(0), 1, 1, 1)) + +Y_one = Y.repeat(X.size(0), 1, 1, 1).contiguous() +def skip_broadcast_bgemm(): + torch.matmul(X, Y) + +X_one = X[0, :].contiguous() +def world_bgemm(): + for i in range(X.size(0)): + torch.matmul(X_one, Y) + +X_two, Y_two = X[:, 0, :].contiguous(), Y[:, 0, :].contiguous() +def expert_bgemm(): + for i in range(X.size(1)): + torch.matmul(X_two, Y_two) + +def backward_reduce_no_sum(): + torch.matmul(X, Z) + +def backward_reduce(): + middle = torch.matmul(X, Z) + torch.sum(middle, dim=0) + +def test(): + evaluate('auto_broadcast_bgemm') + evaluate('manual_broadcast_bgemm') + evaluate('skip_broadcast_bgemm') + evaluate('world_bgemm') + evaluate('expert_bgemm') + evaluate('layout_sgemm') + evaluate('backward_reduce_no_sum') + evaluate('backward_reduce') + +if __name__ == '__main__': + test() +
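# Example invocation (assuming the repository root as the working directory; all flags
# are the argparse options defined above):
#   python3 ./tutel/test/eval_bgemm.py --fp16 --w 1 --e 2 --m 2048 --h 2048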