
Commit

Add distributed executor backend to benchmark scripts (opendatahub-io…
mawong-amd authored Aug 2, 2024
1 parent 3e480e9 commit 42b1b9a
Showing 4 changed files with 47 additions and 25 deletions.
53 changes: 32 additions & 21 deletions benchmarks/benchmark_latency.py
@@ -19,27 +19,30 @@ def main(args: argparse.Namespace):

     # NOTE(woosuk): If the request cannot be processed in a single batch,
     # the engine will automatically process the request in multiple batches.
-    llm = LLM(model=args.model,
-              speculative_model=args.speculative_model,
-              num_speculative_tokens=args.num_speculative_tokens,
-              tokenizer=args.tokenizer,
-              quantization=args.quantization,
-              quantized_weights_path=args.quantized_weights_path,
-              tensor_parallel_size=args.tensor_parallel_size,
-              trust_remote_code=args.trust_remote_code,
-              dtype=args.dtype,
-              enforce_eager=args.enforce_eager,
-              kv_cache_dtype=args.kv_cache_dtype,
-              quantization_param_path=args.quantization_param_path,
-              device=args.device,
-              ray_workers_use_nsight=args.ray_workers_use_nsight,
-              worker_use_ray=args.worker_use_ray,
-              use_v2_block_manager=args.use_v2_block_manager,
-              enable_chunked_prefill=args.enable_chunked_prefill,
-              download_dir=args.download_dir,
-              block_size=args.block_size,
-              disable_custom_all_reduce=args.disable_custom_all_reduce,
-              gpu_memory_utilization=args.gpu_memory_utilization)
+    llm = LLM(
+        model=args.model,
+        speculative_model=args.speculative_model,
+        num_speculative_tokens=args.num_speculative_tokens,
+        tokenizer=args.tokenizer,
+        quantization=args.quantization,
+        quantized_weights_path=args.quantized_weights_path,
+        tensor_parallel_size=args.tensor_parallel_size,
+        trust_remote_code=args.trust_remote_code,
+        dtype=args.dtype,
+        enforce_eager=args.enforce_eager,
+        kv_cache_dtype=args.kv_cache_dtype,
+        quantization_param_path=args.quantization_param_path,
+        device=args.device,
+        ray_workers_use_nsight=args.ray_workers_use_nsight,
+        worker_use_ray=args.worker_use_ray,
+        use_v2_block_manager=args.use_v2_block_manager,
+        enable_chunked_prefill=args.enable_chunked_prefill,
+        download_dir=args.download_dir,
+        block_size=args.block_size,
+        disable_custom_all_reduce=args.disable_custom_all_reduce,
+        gpu_memory_utilization=args.gpu_memory_utilization,
+        distributed_executor_backend=args.distributed_executor_backend,
+    )

     sampling_params = SamplingParams(
         n=args.n,
@@ -237,5 +240,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
                         help='the fraction of GPU memory to be used for '
                         'the model executor, which can range from 0 to 1.'
                         'If unspecified, will use the default value of 0.9.')
+    parser.add_argument(
+        '--distributed-executor-backend',
+        choices=['ray', 'mp', 'torchrun'],
+        default=None,
+        help='Backend to use for distributed serving. When more than 1 GPU '
+        'is used, on CUDA this will be automatically set to "ray" if '
+        'installed or "mp" (multiprocessing) otherwise. On ROCm, this is '
+        'instead set to torchrun by default.')
     args = parser.parse_args()
     main(args)
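
As a quick orientation for readers of this commit, here is a minimal standalone sketch (not part of the diff) of what benchmark_latency.py now does with the flag: whatever --distributed-executor-backend parses to is forwarded unchanged to the LLM constructor, and leaving it as None lets vLLM pick the platform default described in the help text. The model name and tensor_parallel_size below are placeholder values.

# Hedged sketch, assuming vLLM is installed and multiple GPUs are available.
# "facebook/opt-125m" and tensor_parallel_size=2 are illustrative only.
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=2,              # >1 GPU engages the distributed executor
    distributed_executor_backend="mp",   # or "ray" / "torchrun"; None = platform default
)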
15 changes: 13 additions & 2 deletions benchmarks/benchmark_throughput.py
@@ -79,6 +79,7 @@ def run_vllm(
     enable_prefix_caching: bool,
     enable_chunked_prefill: bool,
     max_num_batched_tokens: int,
+    distributed_executor_backend: Optional[str],
     gpu_memory_utilization: float = 0.9,
     worker_use_ray: bool = False,
     download_dir: Optional[str] = None,
@@ -104,6 +105,7 @@ def run_vllm(
         download_dir=download_dir,
         enable_chunked_prefill=enable_chunked_prefill,
         max_num_batched_tokens=max_num_batched_tokens,
+        distributed_executor_backend=distributed_executor_backend,
     )

     # Add the requests to the engine.
@@ -229,8 +231,9 @@ def main(args: argparse.Namespace):
             args.max_model_len, args.enforce_eager, args.kv_cache_dtype,
             args.quantization_param_path, args.device,
             args.enable_prefix_caching, args.enable_chunked_prefill,
-            args.max_num_batched_tokens, args.gpu_memory_utilization,
-            args.worker_use_ray, args.download_dir)
+            args.max_num_batched_tokens, args.distributed_executor_backend,
+            args.gpu_memory_utilization, args.worker_use_ray,
+            args.download_dir)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -384,6 +387,14 @@ def main(args: argparse.Namespace):
         type=str,
         default=None,
         help='Path to save the throughput results in JSON format.')
+    parser.add_argument(
+        '--distributed-executor-backend',
+        choices=['ray', 'mp', 'torchrun'],
+        default=None,
+        help='Backend to use for distributed serving. When more than 1 GPU '
+        'is used, on CUDA this will be automatically set to "ray" if '
+        'installed or "mp" (multiprocessing) otherwise. On ROCm, this is '
+        'instead set to torchrun by default.')
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
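
For a self-contained illustration (not part of this commit) of how the new option parses, the snippet below mirrors the choices and default added in both benchmark scripts; the parser description, the condensed help text, and the sample argv are made up for the example.

# Standalone sketch of the new CLI option; choices and default mirror the diff above.
import argparse

parser = argparse.ArgumentParser(description='distributed-executor-backend flag sketch')
parser.add_argument(
    '--distributed-executor-backend',
    choices=['ray', 'mp', 'torchrun'],
    default=None,
    help='None lets vLLM pick a platform default: "ray"/"mp" on CUDA, '
    '"torchrun" on ROCm.')

args = parser.parse_args(['--distributed-executor-backend', 'mp'])
print(args.distributed_executor_backend)  # prints: mp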
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -599,7 +599,7 @@ def __init__(
         if self.distributed_executor_backend is None and self.world_size > 1:
             if is_hip():
                 logger.info("Using torchrun for multi-GPU on "
-                            "ROCM platform. Use --worker-use-ray or "
+                            "ROCm platform. Use --worker-use-ray or "
                             "--distributed-executor-backend={ray, mp} to "
                             "override")
                 if not os.environ.get("RANK"):
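
To make the hunk above easier to follow, here is an illustrative standalone sketch of the default-selection behavior it describes. The function name, the ray_installed parameter, and the RuntimeError are assumptions for the example; they are not vLLM's actual code, and what vLLM does when RANK is unset is not shown in this diff.

# Illustrative-only sketch of the default backend choice described above.
import os
from typing import Optional

def pick_default_backend(world_size: int, on_rocm: bool,
                         ray_installed: bool) -> Optional[str]:
    if world_size <= 1:
        return None  # single GPU: no distributed executor backend needed
    if on_rocm:
        # torchrun exports RANK for every worker it launches; its absence
        # suggests the script was not started under torchrun (handling assumed).
        if not os.environ.get("RANK"):
            raise RuntimeError("Launch under torchrun, or pass "
                               "--distributed-executor-backend={ray,mp}.")
        return "torchrun"
    return "ray" if ray_installed else "mp"

# e.g. under `torchrun --nproc-per-node=2 benchmark_latency.py ...` this returns "torchrun"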
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -233,7 +233,7 @@ def add_cli_args(
             help='Backend to use for distributed serving. When more than 1 GPU '
             'is used, on CUDA this will be automatically set to "ray" if '
             'installed or "mp" (multiprocessing) otherwise. On ROCm, this is '
-            'instead automatically set to torchrun.')
+            'instead set to torchrun by default.')
         parser.add_argument(
             '--worker-use-ray',
             action='store_true',
