From 67c2d60d91887da85da010acd33ce4042499e41b Mon Sep 17 00:00:00 2001
From: sfc-gh-hazhang
Date: Tue, 7 May 2024 11:16:20 +0000
Subject: [PATCH] run linting

---
 examples/save_state_dict.py                | 4 +++-
 vllm/model_executor/model_loader/loader.py | 1 +
 vllm/worker/model_runner.py                | 5 +++--
 vllm/worker/worker.py                      | 2 +-
 4 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/examples/save_state_dict.py b/examples/save_state_dict.py
index 5dc9853f6..0d8400fb6 100644
--- a/examples/save_state_dict.py
+++ b/examples/save_state_dict.py
@@ -45,7 +45,9 @@ def main(args):
     Path(args.output).mkdir(exist_ok=True)
     # Dump worker states to output directory
     model_executor = llm.llm_engine.model_executor
-    model_executor._run_workers("save_model", path=args.output, max_size=5 * 1024 ** 3)
+    model_executor._run_workers("save_model",
+                                path=args.output,
+                                max_size=5 * 1024**3)
     # Copy metadata files to output directory
     for file in os.listdir(model_path):
         if not any(
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 33a2f6950..20da12bed 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -362,6 +362,7 @@ def load_model(self, *, model_config: ModelConfig,
                    parallel_config: ParallelConfig,
                    scheduler_config: SchedulerConfig) -> nn.Module:
         from safetensors.torch import load_file
+        from vllm.distributed import get_tensor_model_parallel_rank
 
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 997226f1f..ea73fb216 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -212,13 +212,14 @@ def load_model(self) -> None:
                     "but the KV cache data type is not FP8. "
                     "KV cache scaling factors will not be used.")
 
-    def save_model(self, path: str, max_size: int=None) -> None:
+    def save_model(self, path: str, max_size: int) -> None:
         from safetensors.torch import save_file
+        from vllm.distributed import get_tensor_model_parallel_rank
         rank = get_tensor_model_parallel_rank()
         idx = 0
         size = 0
-        params = {}
+        params: Dict[str, torch.Tensor] = {}
         for name, param in self.model.named_parameters():
             param_size = param.nelement() * param.element_size()
             if max_size and size + param_size > max_size:
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 49e294425..707d221b4 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -117,7 +117,7 @@ def init_device(self) -> None:
     def load_model(self):
         self.model_runner.load_model()
 
-    def save_model(self, path: str, max_size: int=None) -> None:
+    def save_model(self, path: str, max_size: int) -> None:
         self.model_runner.save_model(path, max_size=max_size)
 
     @torch.inference_mode()
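
Note for reviewers: the model_runner.py hunk is truncated just after the max_size check, so the shard-flushing part of save_model is not visible in this patch. The following is a minimal standalone sketch of a max_size-bounded sharded save built on safetensors; the save_sharded_state name and the "model-rank-{rank}-part-{idx}.safetensors" shard naming are illustrative assumptions for this sketch, not necessarily what vLLM writes.

    # Illustrative sketch only (not the patched vLLM code): save a model's
    # parameters as a series of safetensors shards, each at most max_size bytes.
    import os
    from typing import Dict

    import torch
    from safetensors.torch import save_file


    def save_sharded_state(model: torch.nn.Module,
                           path: str,
                           max_size: int,
                           rank: int = 0) -> None:
        idx = 0
        size = 0
        params: Dict[str, torch.Tensor] = {}
        for name, param in model.named_parameters():
            param_size = param.nelement() * param.element_size()
            # Flush the current shard before adding a tensor that would
            # push it past max_size.
            if max_size and params and size + param_size > max_size:
                save_file(
                    params,
                    os.path.join(path,
                                 f"model-rank-{rank}-part-{idx}.safetensors"))
                idx += 1
                size = 0
                params = {}
            # Copy to CPU so the shard can be written without holding GPU refs.
            params[name] = param.detach().cpu()
            size += param_size
        if params:
            save_file(
                params,
                os.path.join(path,
                             f"model-rank-{rank}-part-{idx}.safetensors"))

As the examples/save_state_dict.py hunk shows, each worker's save_model is driven through model_executor._run_workers("save_model", path=args.output, max_size=5 * 1024**3), i.e. shards are capped at 5 GiB per file and each tensor-parallel rank writes its own set of shards.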