Skip to content

Commit

Permalink
Load from 'model_state_dict' if present
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 665996907
  • Loading branch information
talumbau authored and copybara-github committed Aug 21, 2024
1 parent dad940b commit 6836309
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions ai_edge_torch/generative/utilities/loader.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -72,7 +72,7 @@ def load_pytorch_statedict(full_path: str):
patterns = []
if os.path.isdir(full_path):
patterns.append(os.path.join(full_path, "*.bin"))
- patterns.append(os.path.join(full_path, "*.pt"))
+ patterns.append(os.path.join(full_path, "*pt"))
else:
patterns.append(full_path)
for pattern in patterns:
Expand Down Expand Up @@ -149,6 +149,7 @@ def load(
enabled.
"""
state = self._loader(self._file_name)
+ state = state["model_state_dict"] if "model_state_dict" in state else state
converted_state = dict()
if self._names.embedding is not None:
converted_state["tok_embedding.weight"] = state.pop(
Expand Down Expand Up @@ -200,7 +201,7 @@ def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]:
if glob.glob(os.path.join(self._file_name, "*.safetensors")):
return load_safetensors
if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
- os.path.join(self._file_name, "*.pt")
+ os.path.join(self._file_name, "*pt")
):
return load_pytorch_statedict

Expand Down

0 comments on commit 6836309

Please sign in to comment.