diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py index 18abeccb2..b139f97ff 100644 --- a/flair/embeddings/token.py +++ b/flair/embeddings/token.py @@ -2,7 +2,7 @@ import logging import re import tempfile -from collections import Counter, Mapping +from collections import Counter from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -1283,6 +1283,8 @@ def __init__( self.static_embeddings = True self.__embedding_length: int = 300 self.language_embeddings: Dict[str, Any] = {} + (KeyedVectors,) = lazy_import("word-embeddings", "gensim.models", "KeyedVectors") + self.kv = KeyedVectors super().__init__() self.eval() @@ -1345,7 +1347,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: embeddings_file = cached_path(f"{hu_path}/muse.{language_code}.vec.gensim", cache_dir=cache_dir) # load the model - self.language_embeddings[language_code] = gensim.models.KeyedVectors.load(str(embeddings_file)) + self.language_embeddings[language_code] = self.kv.load(str(embeddings_file)) for token, _token_idx in zip(sentence.tokens, range(len(sentence.tokens))): word_embedding = self.get_cached_vec(language_code=language_code, word=token.text) @@ -1401,7 +1403,7 @@ def __init__( else: if not language and model_file_path is None: raise ValueError("Need to specify model_file_path if no language is give in BytePairEmbeddings") - BPEmb, = lazy_import("word-embeddings", "bpemb", "BPEmb") + (BPEmb,) = lazy_import("word-embeddings", "bpemb", "BPEmb") if language: self.name: str = f"bpe-{language}-{syllables}-{dim}" @@ -1504,7 +1506,14 @@ def from_params(cls, params): else: embedding_file_path = None dim = params["dim"] - return cls(name=params["name"], dim=dim, model_file_path=model_file_path, embedding_file_path=embedding_file_path, field=params.get("field"), preprocess=params.get("preprocess", True)) + return cls( + name=params["name"], + dim=dim, + model_file_path=model_file_path, + embedding_file_path=embedding_file_path, + field=params.get("field"), + preprocess=params.get("preprocess", True), + ) def to_params(self): return { @@ -1541,7 +1550,7 @@ def state_dict(self, *args, **kwargs): return super().state_dict(*args, **kwargs) def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): if not state_dict: # old embeddings do not have a torch-embedding and therefore do not store the weights in the saved torch state_dict