initial tutel version
Signed-off-by: Wei CUI <[email protected]>
ghostplant committed Sep 14, 2021
1 parent 5f30e02 commit bb316e4
Showing 15 changed files with 1,207 additions and 8 deletions.
46 changes: 38 additions & 8 deletions README.md
@@ -1,14 +1,44 @@
# Project
# Project Tutel

> This repo has been populated by an initial template to help get you started. Please
> make sure to update the content to build a great experience for community-building.
Tutel MoE: An Optimized Mixture-of-Experts Implementation.

As the maintainer of this project, please make a few updates:
- Supported Framework: PyTorch
- Supported GPUs: CUDA (fp32 + fp16), ROCm (fp32)

- Improving this README.MD file to provide a great experience
- Updating SUPPORT.MD with content about this project's support experience
- Understanding the security reporting process in SECURITY.MD
- Remove this section from the README
How to set up Tutel MoE for PyTorch:
```
* Install Online:
$ python3 -m pip install --user https://github.com/microsoft/tutel/releases/download/v0.1.0/tutel-0.1.0.tar.gz
* Build from Source:
$ git clone https://github.com/microsoft/tutel
$ python3 ./tutel/setup.py install --user
```

How to use Tutel-optimized MoE in PyTorch:
```
* Tutel MoE Example:
moe_layer = MOELayer('Top2Gate', model_dim, experts={'type': 'ffn', 'hidden_size_per_expert': 1024})
y = moe_layer(x)
* Usage of MOELayer Args:
gate : the string type of the MoE gate, e.g.: Top1Gate, Top2Gate
model_dim : the number of channels of the MoE input tensor
experts : a dict-type config for the built-in expert network, or a torch.nn.Module-type custom expert network
fp32_gate : option to enable mixed precision for the gate network
scan_expert_func : a user-specified lambda function applied to each expert parameter, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
result_func : a user-specified lambda function that formats the MoE output and aux_loss, e.g. `result_func = lambda output: (output, output.l_aux)`
group : the explicit communication group used for all_to_all
seeds : a pair of ints specifying the manual seeds for (shared params, local params)
* Running MoE Hello World Model:
$ python3 -m torch.distributed.launch --nproc_per_node=1 ./examples/helloworld.py
```
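
For reference, below is a hedged sketch that combines the arguments documented above into one keyword-style construction (as used by `examples/helloworld.py` later in this commit); the concrete values are illustrative placeholders, not recommended settings.
```
import torch
from tutel import moe as tutel_moe

# Illustrative values only; the result_func behavior follows the
# argument descriptions above.
moe_layer = tutel_moe.moe_layer(
    gate_type = 'Top2Gate',
    model_dim = 2048,
    experts = {'type': 'ffn', 'count_per_node': 2, 'hidden_size_per_expert': 1024},
    fp32_gate = True,
    scan_expert_func = lambda name, param: setattr(param, 'expert', True),
    result_func = lambda output: (output, output.l_aux),
    seeds = (1, 1),
).to('cuda')

x = torch.randn([4, 512, 2048], device='cuda')
y, l_aux = moe_layer(x)  # with result_func above, the layer is expected to return (output, aux_loss)
```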

## Contributing

125 changes: 125 additions & 0 deletions examples/helloworld.py
@@ -0,0 +1,125 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import time
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
from torch import nn
import argparse

from tutel import moe as tutel_moe

parser = argparse.ArgumentParser()

parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--num_tokens', type=int, default=512)
parser.add_argument('--model_dim', type=int, default=2048)
parser.add_argument('--hidden_size', type=int, default=1024)
parser.add_argument('--num_local_experts', type=int, default=2)
parser.add_argument('--dtype', type=str, default='float32')
parser.add_argument('--fp32_gate', default=False, action='store_true')
parser.add_argument('--top', type=int, default=2)
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)

try:
if dist.is_available():
dist.init_process_group('nccl')
dist_rank = dist.get_rank()
dist_world_size = dist.get_world_size()

def dist_print(*args):
if dist_rank == 0:
print(*args)
except:
dist_rank = 0
dist_world_size = 1
dist_print = print

batch_size = args.batch_size
num_tokens = args.num_tokens
model_dim = args.model_dim
hidden_size = args.hidden_size
num_local_experts = args.num_local_experts
top_value = args.top
local_rank = args.local_rank


device = torch.device('cuda', args.local_rank)

if args.dtype == 'float32':
torch.set_default_dtype(torch.float32)
elif args.dtype == 'float16':
torch.set_default_dtype(torch.float16)
else:
raise Exception('Unrecognized data type specified: %s' % args.dtype)


class ExampleModel(torch.nn.Module):
def __init__(self):
super().__init__()

self._moe_layer = tutel_moe.moe_layer(
gate_type = 'Top%dGate' % top_value,
model_dim = model_dim,
experts = {'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_size},
fp32_gate = args.fp32_gate,
scan_expert_func = lambda name, param: setattr(param, 'expert', True),
seeds = (1, dist_rank + 1),
).to(device)

# Distinguish different parameter types: gate, local_experts
local_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='local_experts')])
shared_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='gate')])
dist_print('[Statistics] param count for MoE local_experts = %s, param count for MoE gate = %s.\n' % (local_count, shared_count))

def forward(self, input):
result = self._moe_layer(input)
result = F.log_softmax(torch.sum(result, dim=2), dim=1)
return result

model = ExampleModel()
dist_print(model)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

x = torch.randn([batch_size, num_tokens, model_dim], device=device, requires_grad=True)
y = torch.LongTensor(batch_size).random_(1).to(device)

tuples = (dist_world_size, args.dtype, model_dim, hidden_size, batch_size * num_tokens, num_local_experts, top_value, device)
dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, topK = %s, device = `%s`' % tuples)

average_time, num_steps = 0, 100

params_for_all_reduce = [p for p in model.parameters() if not hasattr(p, 'expert') and getattr(p, 'requires_grad', False)]
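# Only shared (non-expert) parameters are averaged across ranks in the loop below;
# parameters tagged 'expert' by scan_expert_func above stay local to each process.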

for i in range(num_steps):

torch.cuda.synchronize()
t_start = time.time()
optimizer.zero_grad()

output = model(x)
loss = F.nll_loss(output, y)
loss.backward()
if dist_world_size > 1:
for p in params_for_all_reduce:
p.grad /= dist_world_size
dist.all_reduce(p.grad)
optimizer.step()

torch.cuda.synchronize()
t_stop = time.time()
dist_print('STEP-%s: DONE, loss = %s, step_time = %s sec.' % (i, float(loss.data), t_stop - t_start))

if i + 10 >= num_steps:
average_time += t_stop - t_start

average_time /= 10
dist_print('\n[Summary] Average synchronized step_time = %s sec.' % average_time)
76 changes: 76 additions & 0 deletions setup.py
@@ -0,0 +1,76 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""The setuptools based setup module.
Reference:
https://packaging.python.org/guides/distributing-packages-using-setuptools/
"""

import os, sys

from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

try:
from torch.utils.cpp_extension import IS_HIP_EXTENSION
except:
IS_HIP_EXTENSION = False

if len(sys.argv) <= 1:
sys.argv += ['install', '--user']

root_path = os.path.dirname(sys.argv[0])
root_path = root_path if root_path else '.'

os.chdir(root_path)

setup(
name='tutel',
version='0.1.0',
description='An Optimized Mixture-of-Experts Implementation.',
url='https://github.com/microsoft/Tutel',
author='Microsoft',
author_email='[email protected]',
license='MIT',
classifiers=[
'Development Status :: 2 - Pre-Alpha',
'Environment :: GPU',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3 :: Only',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
],
keywords=['Mixture of Experts', 'MoE', 'Optimization'],
packages=find_packages(),
python_requires='>=3.6, <4',
install_requires=[
],
ext_modules=[
CUDAExtension('tutel_custom_kernel', [
'./tutel/custom/custom_kernel.cpp',
],
library_dirs=['/usr/local/cuda/lib64/stubs'],
libraries=['dl', 'cuda'] if not IS_HIP_EXTENSION else [])
],
cmdclass={
'build_ext': BuildExtension
},
project_urls={
'Source': 'https://github.com/microsoft/Tutel',
'Tracker': 'https://github.com/microsoft/Tutel/issues',
},
)
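
Note: when run with no command-line arguments, this setup.py defaults `sys.argv` to `install --user` (see the argv handling above), so `python3 ./setup.py` alone should behave like the install command in the README; building the `tutel_custom_kernel` extension assumes a local PyTorch with a CUDA or ROCm toolchain.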
3 changes: 3 additions & 0 deletions tutel/__init__.py
@@ -0,0 +1,3 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

3 changes: 3 additions & 0 deletions tutel/custom/__init__.py
@@ -0,0 +1,3 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

122 changes: 122 additions & 0 deletions tutel/custom/custom_kernel.cpp
@@ -0,0 +1,122 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <vector>

#include <dlfcn.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda.h>

#undef CHECK_EQ
#undef CHECK_NE
#undef CHECK_CUDA
#undef CHECK_CONTIGUOUS

#define CHECK_EQ(x, y) AT_ASSERTM((x) == (y), "CHECK_EQ fails.")
#define CHECK_NE(x, y) AT_ASSERTM((x) != (y), "CHECK_NE fails.")
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")


static void invoke(const std::vector<torch::Tensor> &ts, int _key) {
struct ModuleConfig {
CUmodule hMod = nullptr;
CUfunction hFunc = nullptr;

dim3 blocks, threads;
};

static std::vector<ModuleConfig> gpuMods;

#if !defined(__HIP_PLATFORM_HCC__)
#if 0
static void *libcuda = nullptr;
static int (*cuModuleLoad)(...) = nullptr;
static int (*cuModuleGetFunction)(...) = nullptr;
static int (*cuLaunchKernel)(...) = nullptr;

if (libcuda == nullptr) {
(libcuda == nullptr ? (libcuda = dlopen("/usr/lib/x86_64-linux-gnu/libcuda.so.1", RTLD_LAZY | RTLD_GLOBAL)) : 0);
(libcuda == nullptr ? (libcuda = dlopen("/usr/lib/x86_64-linux-gnu/libcuda.so", RTLD_LAZY | RTLD_GLOBAL)) : 0);
(libcuda == nullptr ? (libcuda = dlopen("/usr/local/lib/x86_64-linux-gnu/libcuda.so.1", RTLD_LAZY | RTLD_GLOBAL)) : 0);
(libcuda == nullptr ? (libcuda = dlopen("/usr/local/lib/x86_64-linux-gnu/libcuda.so", RTLD_LAZY | RTLD_GLOBAL)) : 0);
(libcuda == nullptr ? (libcuda = dlopen("/usr/local/cuda/lib64/libcuda.so.1", RTLD_LAZY | RTLD_GLOBAL)) : 0);
(libcuda == nullptr ? (libcuda = dlopen("/usr/local/cuda/lib64/libcuda.so", RTLD_LAZY | RTLD_GLOBAL)) : 0);
(libcuda == nullptr ? (libcuda = dlopen("/usr/local/cuda/lib64/stubs/libcuda.so", RTLD_LAZY | RTLD_GLOBAL)) : 0);

CHECK_NE(nullptr, libcuda);
CHECK_NE(nullptr, (cuModuleLoad = (decltype(cuModuleLoad))dlsym(libcuda, "cuModuleLoad")));
CHECK_NE(nullptr, (cuModuleGetFunction = (decltype(cuModuleGetFunction))dlsym(libcuda, "cuModuleGetFunction")));
CHECK_NE(nullptr, (cuLaunchKernel = (decltype(cuLaunchKernel))dlsym(libcuda, "cuLaunchKernel")));
}
#endif
#endif

int key_int = (_key & 255), ctx = _key >> 8;
if (ctx >= (int)gpuMods.size())
gpuMods.resize(ctx + 1);

auto &gm = gpuMods[ctx];
if (gm.hFunc == nullptr) {
std::string key = std::to_string(key_int);
std::string file_name = "/tmp/" + std::to_string(ctx) + "-" + key + ".cu";
FILE *fp = fopen(file_name.c_str(), "rb");
CHECK_EQ(true, fp != nullptr);
fseek(fp, 0, SEEK_END);
size_t code_size = ftell(fp);
fseek(fp, 0, SEEK_SET);
std::vector<char> code(code_size + 1);
CHECK_EQ(code_size, fread((void*)code.data(), 1, code_size, fp));
fclose(fp);

int dev = key_int;
CHECK_EQ(0, cudaSetDevice(dev));

#if !defined(__HIP_PLATFORM_HCC__)
std::string cc = "30";
int major, minor;
CHECK_EQ(0, cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev));
CHECK_EQ(0, cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, dev));
std::string arch = std::to_string(major) + std::to_string(minor);
CHECK_EQ(0, system(("/usr/local/cuda/bin/nvcc " + file_name + " -o " + file_name + ".fatbin --fatbin -O2 -gencode arch=compute_" + arch + ",code=sm_" + arch).c_str()));
#else
hipDeviceProp_t prop;
CHECK_EQ(0, hipGetDeviceProperties(&prop, dev));
std::string arch = std::to_string(prop.gcnArch);
CHECK_EQ(0, system(("/opt/rocm/bin/hipcc " + file_name + " -o " + file_name + ".fatbin --genco -O2 -w --amdgpu-target=gfx" + arch).c_str()));
#endif
CHECK_EQ(0, cuModuleLoad(&gm.hMod, (file_name + ".fatbin").c_str()));

const char *source = code.data(), *pos, *tail;
CHECK_EQ(true, nullptr != (pos = strstr(source, " void ")));
pos += 6; CHECK_EQ(true, nullptr != (tail = strchr(pos, '(')));

CHECK_EQ(0, cuModuleGetFunction(&gm.hFunc, gm.hMod, std::string(pos, tail - pos).c_str()));
CHECK_EQ(true, nullptr != gm.hFunc);

{ char tag[] = "// [thread_extent] blockIdx.x = "; pos = strstr(source, tag); gm.blocks.x = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; }
{ char tag[] = "// [thread_extent] blockIdx.y = "; pos = strstr(source, tag); gm.blocks.y = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; }
{ char tag[] = "// [thread_extent] blockIdx.z = "; pos = strstr(source, tag); gm.blocks.z = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; }
{ char tag[] = "// [thread_extent] threadIdx.x = "; pos = strstr(source, tag); gm.threads.x = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; }
{ char tag[] = "// [thread_extent] threadIdx.y = "; pos = strstr(source, tag); gm.threads.y = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; }
{ char tag[] = "// [thread_extent] threadIdx.z = "; pos = strstr(source, tag); gm.threads.z = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; }
}

std::vector<void*> pargs(ts.size()), ppargs(ts.size());
for (int i = 0; i < (int)ts.size(); ++i) {
pargs[i] = (void*)ts[i].data_ptr(), ppargs[i] = &pargs[i];
}

CHECK_EQ(0, cuLaunchKernel(gm.hFunc, gm.blocks.x, gm.blocks.y, gm.blocks.z, gm.threads.x, gm.threads.y, gm.threads.z, 0, nullptr, ppargs.data(), nullptr));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("invoke",
&invoke,
"Generic Invoke (CUDA)"
);
}
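
The extension looks for pre-generated kernel sources under `/tmp/<ctx>-<device>.cu` (see `invoke()` above), compiles them on first use, and reads launch dimensions from `// [thread_extent]` comments. The snippet below is a hypothetical, minimal exercise of that protocol on device 0 with ctx 0; the kernel name, file content, and sizes are made up for illustration, and it assumes `nvcc` is available at `/usr/local/cuda/bin/nvcc` (in normal use, Tutel's Python layer presumably produces these files itself).
```
# Hypothetical sketch of the /tmp/<ctx>-<device>.cu protocol read by invoke();
# the kernel below is illustrative only.
import torch
import tutel_custom_kernel

kernel_src = r'''
extern "C" __global__ void add_one(float* x) {
  // [thread_extent] blockIdx.x = 4
  // [thread_extent] threadIdx.x = 256
  x[blockIdx.x * 256 + threadIdx.x] += 1.0f;
}
'''

# ctx = 0, device = 0  =>  _key = (0 << 8) | 0 = 0, file name "/tmp/0-0.cu"
with open('/tmp/0-0.cu', 'w') as f:
    f.write(kernel_src)

x = torch.zeros(4 * 256, dtype=torch.float32, device='cuda:0')
tutel_custom_kernel.invoke([x], 0)  # compiles the file with nvcc on first call, then launches it
torch.cuda.synchronize()
print(x.sum())  # expected: 1024 (every element incremented once)
```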
3 changes: 3 additions & 0 deletions tutel/impls/__init__.py
@@ -0,0 +1,3 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


