diff --git a/ai_edge_torch/generative/utilities/loader.py b/ai_edge_torch/generative/utilities/loader.py index b8891233..4d9e3832 100644 --- a/ai_edge_torch/generative/utilities/loader.py +++ b/ai_edge_torch/generative/utilities/loader.py @@ -72,7 +72,7 @@ def load_pytorch_statedict(full_path: str): patterns = [] if os.path.isdir(full_path): patterns.append(os.path.join(full_path, "*.bin")) - patterns.append(os.path.join(full_path, "*.pt")) + patterns.append(os.path.join(full_path, "*pt")) else: patterns.append(full_path) for pattern in patterns: @@ -149,6 +149,7 @@ def load( enabled. """ state = self._loader(self._file_name) + state = state["model_state_dict"] if "model_state_dict" in state else state converted_state = dict() if self._names.embedding is not None: converted_state["tok_embedding.weight"] = state.pop( @@ -200,7 +201,7 @@ def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]: if glob.glob(os.path.join(self._file_name, "*.safetensors")): return load_safetensors if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob( - os.path.join(self._file_name, "*.pt") + os.path.join(self._file_name, "*pt") ): return load_pytorch_statedict