Skip to content

Commit

Permalink
Add fallback to load state_dict with strict=False
Browse files Browse the repository at this point in the history
Due to incompatibilities related to `state_dict` keys between
`transformers` v4.30 and v4.31, fall back to loading with
`strict=False`.
  • Loading branch information
adrianeboyd committed Aug 11, 2023
1 parent dcfb779 commit ee26798
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,5 @@ jobs:
- name: Test backwards compatibility for v1.1 models
if: matrix.python_version == '3.9'
run: |
python -m pip install "transformers<4.31"
python -m pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.4.0/en_core_web_trf-3.4.0-py3-none-any.whl --no-deps
python -c "import spacy; nlp = spacy.load('en_core_web_trf'); doc = nlp('test')"
# NOTE: update requirements at the end of this step if any following
# steps are added in the future
# python -m pip install -U -r requirements.txt
23 changes: 13 additions & 10 deletions spacy_transformers/layers/hf_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,22 @@ def from_bytes(self, bytes_data):
device = get_torch_default_device()
try:
self._model.load_state_dict(torch.load(filelike, map_location=device))
except RuntimeError as ex:
except RuntimeError:
warn_msg = (
"Error loading saved torch model. If the error is related "
"to unexpected key(s) in state_dict, a possible workaround "
"is to load this model with 'transformers<4.31'. "
"Alternatively, download a newer compatible model or "
"retrain your custom model with the current "
"transformers and spacy-transformers versions. For more "
"details and available updates, run: python -m spacy "
"validate"
"Error loading saved torch state_dict with strict=True, "
"likely due to differences between 'transformers' "
"versions. Attempting to load with strict=False as a "
"fallback...\n\n"
"If you see errors or degraded performance, download a "
"newer compatible model or retrain your custom model with "
"the current 'transformers' and 'spacy-transformers' "
"versions. For more details and available updates, run: "
"python -m spacy validate"
)
warnings.warn(warn_msg)
raise ex
filelike.seek(0)
b = torch.load(filelike, map_location=device)
self._model.load_state_dict(b, strict=False)
self._model.to(device)
else:
self._hfmodel = HFObjects(
Expand Down

0 comments on commit ee26798

Please sign in to comment.