diff --git a/README.md b/README.md
index c165cba..1437b2a 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # Tutel
 
-Tutel MoE: An Optimized Mixture-of-Experts Implementation.
+Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parallel solution to propose ["No-penalty Parallelism/Sparsity/Capacity/.. Switching"](https://mlsys.org/media/mlsys-2023/Slides/2477.pdf) for modern training and inference workloads with dynamic behaviors.
 
 - Supported Framework: Pytorch (recommend: >= 1.10)
 - Supported GPUs: CUDA(fp64/fp32/fp16/bfp16), ROCm(fp64/fp32/fp16)
@@ -9,6 +9,24 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation.
 
 ### What's New:
 
+- Tutel v0.3.2: Add a tensorcore option for extra benchmarks / extend the example for custom experts / allow NCCL timeout settings:
+```py
+ >> Example of using tensorcore:
+
+   python3 -m tutel.examples.helloworld --dtype=float32
+   python3 -m tutel.examples.helloworld --dtype=float32 --use_tensorcore
+
+   python3 -m tutel.examples.helloworld --dtype=float16
+   python3 -m tutel.examples.helloworld --dtype=float16 --use_tensorcore
+
+ >> Example of custom experts:
+   python3 -m tutel.examples.helloworld_custom_expert --batch_size=16
+
+ >> Example of NCCL timeout settings:
+   TUTEL_GLOBAL_TIMEOUT_SEC=60 python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld --use_tensorcore
+
+```
+
 - Tutel v0.3.1: Add NCCL all_to_all_v and all_gather_v for arbitrary-length message transfers:
 ```py
 >> Example:
@@ -84,7 +102,7 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation.
 $ python3 -m tutel.examples.helloworld_ddp --batch_size=16            # Test Tutel-optimized MoE + Pytorch DDP distribution (requires: Pytorch >= 1.8.0)
 $ python3 -m tutel.examples.helloworld_ddp_tutel --batch_size=16      # Test Tutel-optimized MoE + Tutel DDP distribution (ZeRO on optimizors)
 $ python3 -m tutel.examples.helloworld_amp --batch_size=16            # Test Tutel-optimized MoE with AMP data type + manual distribution
-$ python3 -m tutel.examples.helloworld_demo --batch_size=16           # Test Tutel-optimized MoE + custom defined expert layer
+$ python3 -m tutel.examples.helloworld_custom_expert --batch_size=16  # Test Tutel-optimized MoE + custom defined expert layer
 $ python3 -m tutel.examples.helloworld_from_scratch                   # Test Custom MoE implementation from scratch
 $ python3 -m tutel.examples.moe_mnist                                 # Test MoE layer in end-to-end MNIST dataset
 $ python3 -m tutel.examples.moe_cifar10                               # Test MoE layer in end-to-end CIFAR10 dataset
diff --git a/tutel/examples/helloworld.py b/tutel/examples/helloworld.py
index 5f50b77..b98c2f3 100755
--- a/tutel/examples/helloworld.py
+++ b/tutel/examples/helloworld.py
@@ -36,9 +36,13 @@
 parser.add_argument('--eval', default=False, action='store_true')
 parser.add_argument('--capacity_factor', type=float, default=1.0)  # 0.0 for dMoE (dropless-MoE), negative for no-padded capacity.
 parser.add_argument('--megablocks_size', type=int, default=0)
+parser.add_argument('--use_tensorcore', default=False, action='store_true')
 
 args = parser.parse_args()
 
+if args.use_tensorcore:
+    torch.backends.cuda.matmul.allow_tf32 = True
+
 parallel_env = system.init_data_model_parallel(backend='nccl' if args.device == 'cuda' else 'gloo')
 dist_rank, dist_world_size, dist_print = parallel_env.global_rank, parallel_env.global_size, parallel_env.dist_print
 args.local_rank = parallel_env.local_device.index
diff --git a/tutel/examples/helloworld_demo.py b/tutel/examples/helloworld_custom_expert.py
old mode 100644
new mode 100755
similarity index 100%
rename from tutel/examples/helloworld_demo.py
rename to tutel/examples/helloworld_custom_expert.py
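
Note (not part of the patch): the new `--use_tensorcore` flag works by enabling TF32, which lets CUDA execute float32 matmuls on NVIDIA tensor cores (compute capability 8.0+, i.e. Ampere or newer) with a 10-bit mantissa instead of fp32's 23 bits. Since PyTorch 1.12 this switch defaults to off for matmuls, which is why the flag makes a visible difference in float32 benchmarks. Below is a minimal standalone sketch of the effect; it assumes a CUDA-capable GPU, and the `benchmark_matmul` helper is illustrative rather than anything shipped by Tutel:

```py
# Sketch only: demonstrates the TF32 switch that helloworld.py flips
# when --use_tensorcore is passed. Assumes a CUDA device is available.
import time
import torch

def benchmark_matmul(n=4096, iters=10):
    # Hypothetical helper: average time of an fp32 matmul under the
    # currently active torch.backends.cuda.matmul.allow_tf32 setting.
    x = torch.randn(n, n, device='cuda')
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        x @ x
    torch.cuda.synchronize()
    return (time.time() - t0) / iters

torch.backends.cuda.matmul.allow_tf32 = False   # PyTorch >= 1.12 default
fp32_time = benchmark_matmul()
torch.backends.cuda.matmul.allow_tf32 = True    # what --use_tensorcore enables
tf32_time = benchmark_matmul()
print(f'fp32: {fp32_time:.4f}s/iter, tf32: {tf32_time:.4f}s/iter')
```

On pre-Ampere GPUs the switch is effectively a no-op for float32, and for `--dtype=float16` tensor cores are already used regardless of this setting.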