From ee2679836c3f239061e3c607a467fa7058a7d1ca Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Fri, 11 Aug 2023 15:31:40 +0200 Subject: [PATCH] Add fallback to load state_dict with strict=False Due to incompatibilities related to `state_dict` keys between `transformers` v4.30 and v4.31, fall back to loading with `strict=False`. --- .github/workflows/tests.yml | 4 ---- spacy_transformers/layers/hf_shim.py | 23 +++++++++++++---------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2ac668f7..04324cfd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -108,9 +108,5 @@ jobs: - name: Test backwards compatibility for v1.1 models if: matrix.python_version == '3.9' run: | - python -m pip install "transformers<4.31" python -m pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.4.0/en_core_web_trf-3.4.0-py3-none-any.whl --no-deps python -c "import spacy; nlp = spacy.load('en_core_web_trf'); doc = nlp('test')" - # NOTE: update requirements at the end of this step if any following - # steps are added in the future - # python -m pip install -U -r requirements.txt diff --git a/spacy_transformers/layers/hf_shim.py b/spacy_transformers/layers/hf_shim.py index 567030a8..90ef944e 100644 --- a/spacy_transformers/layers/hf_shim.py +++ b/spacy_transformers/layers/hf_shim.py @@ -122,19 +122,22 @@ def from_bytes(self, bytes_data): device = get_torch_default_device() try: self._model.load_state_dict(torch.load(filelike, map_location=device)) - except RuntimeError as ex: + except RuntimeError: warn_msg = ( - "Error loading saved torch model. If the error is related " - "to unexpected key(s) in state_dict, a possible workaround " - "is to load this model with 'transformers<4.31'. " - "Alternatively, download a newer compatible model or " - "retrain your custom model with the current " - "transformers and spacy-transformers versions. For more " - "details and available updates, run: python -m spacy " - "validate" + "Error loading saved torch state_dict with strict=True, " + "likely due to differences between 'transformers' " + "versions. Attempting to load with strict=False as a " + "fallback...\n\n" + "If you see errors or degraded performance, download a " + "newer compatible model or retrain your custom model with " + "the current 'transformers' and 'spacy-transformers' " + "versions. For more details and available updates, run: " + "python -m spacy validate" ) warnings.warn(warn_msg) - raise ex + filelike.seek(0) + b = torch.load(filelike, map_location=device) + self._model.load_state_dict(b, strict=False) self._model.to(device) else: self._hfmodel = HFObjects(