From 6836309765fc72e9aed6d5c6954699eeafa40c11 Mon Sep 17 00:00:00 2001
From: "T.J. Alumbaugh"
Date: Wed, 21 Aug 2024 12:58:30 -0700
Subject: [PATCH] Load from 'model_state_dict' if present

PiperOrigin-RevId: 665996907
---
 ai_edge_torch/generative/utilities/loader.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/ai_edge_torch/generative/utilities/loader.py b/ai_edge_torch/generative/utilities/loader.py
index b8891233..4d9e3832 100644
--- a/ai_edge_torch/generative/utilities/loader.py
+++ b/ai_edge_torch/generative/utilities/loader.py
@@ -72,7 +72,7 @@ def load_pytorch_statedict(full_path: str):
   patterns = []
   if os.path.isdir(full_path):
     patterns.append(os.path.join(full_path, "*.bin"))
-    patterns.append(os.path.join(full_path, "*.pt"))
+    patterns.append(os.path.join(full_path, "*pt"))
   else:
     patterns.append(full_path)
   for pattern in patterns:
@@ -149,6 +149,7 @@ def load(
         enabled.
     """
     state = self._loader(self._file_name)
+    state = state["model_state_dict"] if "model_state_dict" in state else state
     converted_state = dict()
     if self._names.embedding is not None:
       converted_state["tok_embedding.weight"] = state.pop(
@@ -200,7 +201,7 @@ def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]:
      if glob.glob(os.path.join(self._file_name, "*.safetensors")):
        return load_safetensors
      if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
-          os.path.join(self._file_name, "*.pt")
+          os.path.join(self._file_name, "*pt")
      ):
        return load_pytorch_statedict
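
For context, a minimal sketch (not part of the patch; the toy model and file
paths are hypothetical) of the two checkpoint layouts the patched load() now
accepts: a bare state dict, and a training-loop checkpoint that nests the
weights under a "model_state_dict" key.

    import torch

    # Hypothetical toy model, used only to produce a state dict for illustration.
    model = torch.nn.Linear(4, 2)

    # Layout 1: a bare state dict (already handled before this patch).
    torch.save(model.state_dict(), "/tmp/bare.pt")

    # Layout 2: a training-loop checkpoint that wraps the weights under
    # "model_state_dict"; the new line in load() unwraps it.
    torch.save(
        {"model_state_dict": model.state_dict(), "epoch": 3},
        "/tmp/wrapped.pt",
    )

    # Equivalent of what the patched load() does before remapping tensor names:
    state = torch.load("/tmp/wrapped.pt")
    state = state["model_state_dict"] if "model_state_dict" in state else state
    assert "weight" in state and "bias" in state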