run linting
sfc-gh-hazhang committed May 7, 2024
1 parent 332890b commit 67c2d60
Showing 4 changed files with 8 additions and 4 deletions.
4 changes: 3 additions & 1 deletion examples/save_state_dict.py
@@ -45,7 +45,9 @@ def main(args):
     Path(args.output).mkdir(exist_ok=True)
     # Dump worker states to output directory
     model_executor = llm.llm_engine.model_executor
-    model_executor._run_workers("save_model", path=args.output, max_size=5 * 1024 ** 3)
+    model_executor._run_workers("save_model",
+                                path=args.output,
+                                max_size=5 * 1024**3)
     # Copy metadata files to output directory
     for file in os.listdir(model_path):
         if not any(
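For context, _run_workers takes the worker method name as a string and fans the call out to every worker process, so the reformatted call above asks each tensor-parallel rank to run its own save_model with a 5 GiB (5 * 1024**3 bytes) shard cap. A toy sketch of that string-dispatch pattern, not vLLM's actual executor code:

    # Illustrative only: shows the dispatch pattern that
    # model_executor._run_workers("save_model", ...) relies on.
    from typing import Any, List

    class ToyExecutor:

        def __init__(self, workers: List[Any]) -> None:
            self.workers = workers

        def _run_workers(self, method: str, *args, **kwargs) -> list:
            # Look up `method` on each worker and call it with the same
            # arguments; vLLM does this once per tensor-parallel rank.
            return [getattr(w, method)(*args, **kwargs) for w in self.workers]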
1 change: 1 addition & 0 deletions vllm/model_executor/model_loader/loader.py
@@ -362,6 +362,7 @@ def load_model(self, *, model_config: ModelConfig,
                    parallel_config: ParallelConfig,
                    scheduler_config: SchedulerConfig) -> nn.Module:
         from safetensors.torch import load_file
+
         from vllm.distributed import get_tensor_model_parallel_rank
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
5 changes: 3 additions & 2 deletions vllm/worker/model_runner.py
@@ -212,13 +212,14 @@ def load_model(self) -> None:
                 "but the KV cache data type is not FP8. "
                 "KV cache scaling factors will not be used.")
 
-    def save_model(self, path: str, max_size: int=None) -> None:
+    def save_model(self, path: str, max_size: int) -> None:
         from safetensors.torch import save_file
+
         from vllm.distributed import get_tensor_model_parallel_rank
         rank = get_tensor_model_parallel_rank()
         idx = 0
         size = 0
-        params = {}
+        params: Dict[str, torch.Tensor] = {}
         for name, param in self.model.named_parameters():
             param_size = param.nelement() * param.element_size()
             if max_size and size + param_size > max_size:
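The hunk above shows only the start of the size-capped save loop. The sketch below fills in how such a loop typically finishes; the save_sharded name and the rank-{rank}-part-{idx}.safetensors file pattern are illustrative assumptions, not taken from this commit:

    # Sketch of a byte-capped sharded save; assumptions noted inline.
    import os
    from typing import Dict

    import torch
    from safetensors.torch import save_file


    def save_sharded(model: torch.nn.Module, path: str, max_size: int,
                     rank: int) -> None:
        idx = 0
        size = 0
        params: Dict[str, torch.Tensor] = {}
        for name, param in model.named_parameters():
            param_size = param.nelement() * param.element_size()
            # Start a new shard before this tensor would push the current
            # file past the byte cap.
            if max_size and size + param_size > max_size and params:
                save_file(
                    params,
                    os.path.join(path,
                                 f"rank-{rank}-part-{idx}.safetensors"))
                idx += 1
                size = 0
                params = {}
            params[name] = param.detach().cpu().contiguous()
            size += param_size
        if params:
            # Flush whatever is left as the final shard.
            save_file(
                params,
                os.path.join(path,
                             f"rank-{rank}-part-{idx}.safetensors"))

The sketch reuses the names visible in the diff (idx, size, params, param_size), so the partial hunk above reads as the start of exactly this kind of loop.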
2 changes: 1 addition & 1 deletion vllm/worker/worker.py
@@ -117,7 +117,7 @@ def init_device(self) -> None:
     def load_model(self):
         self.model_runner.load_model()
 
-    def save_model(self, path: str, max_size: int=None) -> None:
+    def save_model(self, path: str, max_size: int) -> None:
         self.model_runner.save_model(path, max_size=max_size)
 
     @torch.inference_mode()
