diff --git a/vllm/executor/torchrun_gpu_executor.py b/vllm/executor/torchrun_gpu_executor.py
index 2a4a8be5ad40b..59a82ef61fc38 100644
--- a/vllm/executor/torchrun_gpu_executor.py
+++ b/vllm/executor/torchrun_gpu_executor.py
@@ -1,6 +1,8 @@
 import os
 from typing import Dict, List, Optional
 
+import torch
+
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VisionLanguageConfig)
@@ -8,7 +10,7 @@
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
 from vllm.model_executor.parallel_utils.communication_op import (
-    broadcast_object_list)
+    broadcast_object_list, tensor_model_parallel_all_gather)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         make_async)
@@ -63,6 +65,17 @@ def _init_worker(self):
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
+    def determine_num_available_blocks(self) -> tuple[int, int]:
+        num_gpu_blocks, num_cpu_blocks = (
+            self.driver_worker.determine_num_available_blocks())
+        t = torch.tensor(
+            [[num_gpu_blocks], [num_cpu_blocks]],
+            device="cuda",
+            dtype=torch.int32,
+        )
+        output = tensor_model_parallel_all_gather(t)
+        return (torch.min(output[0]).item(), torch.min(output[1]).item())
+
     def execute_model(self,
                       seq_group_metadata_list: List[SequenceGroupMetadata],
                       blocks_to_swap_in: Dict[int, int],
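
For context: under torchrun each rank runs its own executor and profiles its own GPU, so the ranks can arrive at different block counts (e.g. when free memory differs across devices). The new method all-gathers a `(2, 1)` tensor into one column per rank and takes the per-row minimum, so every rank sizes its KV cache to fit the most constrained GPU. Below is a minimal, self-contained sketch of this all-gather-then-min pattern using plain `torch.distributed` with the `gloo` backend as a stand-in for vLLM's `tensor_model_parallel_all_gather`; the function name `agree_on_num_blocks` and the demo values are hypothetical.

```python
import os

import torch
import torch.distributed as dist


def agree_on_num_blocks(num_gpu_blocks: int, num_cpu_blocks: int) -> tuple[int, int]:
    """All-gather each rank's block counts and take the per-row minimum,
    so every rank agrees on a cache size that fits the smallest GPU."""
    t = torch.tensor([num_gpu_blocks, num_cpu_blocks], dtype=torch.int32)
    gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, t)
    stacked = torch.stack(gathered)  # shape: (world_size, 2)
    return (int(stacked[:, 0].min()), int(stacked[:, 1].min()))


if __name__ == "__main__":
    # Single-process demo; under torchrun these env vars are set per rank.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(agree_on_num_blocks(1024, 512))  # -> (1024, 512)
    dist.destroy_process_group()
```

The min-reduction is what matters here: a plain broadcast from the driver would not be safe, since the driver's GPU is not guaranteed to be the one with the least headroom.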