From 2fae0a2e6c90516b1bd6d03f7f65e3e19127d726 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Sun, 24 Dec 2023 14:30:15 +0100 Subject: [PATCH] add hf model download --- flair/file_utils.py | 38 +++++++++++++++++++++ flair/models/entity_mention_linking.py | 21 ++++++++++-- flair/models/sequence_tagger_model.py | 45 ++----------------------- tests/test_biomedical_entity_linking.py | 8 ++--- 4 files changed, 63 insertions(+), 49 deletions(-) diff --git a/flair/file_utils.py b/flair/file_utils.py index 7f0ba5f9e7..1a2acc884c 100644 --- a/flair/file_utils.py +++ b/flair/file_utils.py @@ -20,6 +20,7 @@ import torch from botocore import UNSIGNED from botocore.config import Config +from requests import HTTPError from tqdm import tqdm as _tqdm import flair @@ -143,6 +144,43 @@ def unzip_file(file: Union[str, Path], unzip_to: Union[str, Path]): zipObj.extractall(Path(unzip_to)) +def hf_download(model_name: str) -> str: + hf_model_name = "pytorch_model.bin" + revision = "main" + + if "@" in model_name: + model_name_split = model_name.split("@") + revision = model_name_split[-1] + model_name = model_name_split[0] + + # use model name as subfolder + model_folder = model_name.split("/", maxsplit=1)[1] if "/" in model_name else model_name + + # Lazy import + from huggingface_hub.file_download import hf_hub_download + + try: + return hf_hub_download( + repo_id=model_name, + filename=hf_model_name, + revision=revision, + library_name="flair", + library_version=flair.__version__, + cache_dir=flair.cache_root / "models" / model_folder, + ) + except HTTPError: + # output information + logger.error("-" * 80) + logger.error( + f"ERROR: The key '{model_name}' was neither found on the ModelHub nor is this a valid path to a file on your system!" + ) + logger.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.") + logger.error(" -> Alternatively, point to a model file on your local drive.") + logger.error("-" * 80) + Path(flair.cache_root / "models" / model_folder).rmdir() # remove folder again if not valid + raise + + def unpack_file(file: Path, unpack_to: Path, mode: Optional[str] = None, keep: bool = True): """Unpacks an archive file to the given location. diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 18907bb451..74099ae502 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -31,7 +31,7 @@ from flair.datasets.entity_linking import InMemoryEntityLinkingDictionary from flair.embeddings import DocumentEmbeddings, DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings from flair.embeddings.base import load_embeddings -from flair.file_utils import cached_path +from flair.file_utils import cached_path, hf_download from flair.training_utils import Result logger = logging.getLogger("flair") @@ -772,7 +772,24 @@ def _fetch_model(model_name: str) -> str: if Path(model_name).exists(): return model_name - raise NotImplementedError + bio_base_repo = "helpmefindaname" + + hf_model_map = { + "bio-gene": f"{bio_base_repo}/flair-eml-sapbert-bc2gn-gene", + "bio-disease": f"{bio_base_repo}/flair-eml-sapbert-bc5cdr-disease", + "bio-chemical": f"{bio_base_repo}/flair-eml-sapbert-bc5cdr-chemical", + "bio-species": f"{bio_base_repo}/flair-eml-species-exact-match", + "bio-gene-exact-match": f"{bio_base_repo}/flair-eml-gene-exact-match", + "bio-disease-exact-match": f"{bio_base_repo}/flair-eml-disease-exact-match", + "bio-chemical-exact-match": f"{bio_base_repo}/flair-eml-chemical-exact-match", + "bio-species-exact-match": f"{bio_base_repo}/flair-eml-species-exact-match", + } + + if model_name in hf_model_map: + model_name = hf_model_map[model_name] + + return hf_download(model_name) + @classmethod def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs) -> "EntityMentionLinker": diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index c6defd24a6..2f1bab67e5 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -14,7 +14,7 @@ from flair.data import Dictionary, Label, Sentence, Span, get_spans_from_bio from flair.datasets import DataLoader, FlairDatapointDataset from flair.embeddings import TokenEmbeddings -from flair.file_utils import cached_path, unzip_file +from flair.file_utils import cached_path, unzip_file, hf_download from flair.models.sequence_tagger_utils.crf import CRF from flair.models.sequence_tagger_utils.viterbi import ViterbiDecoder, ViterbiLoss from flair.training_utils import store_embeddings @@ -775,9 +775,7 @@ def _fetch_model(model_name) -> str: # get mapped name hf_model_name = huggingface_model_map[model_name] - # use mapped name instead - model_name = hf_model_name - get_from_model_hub = True + model_path = hf_download(hf_model_name) # if not, check if model key is remapped to direct download location. If so, download model elif model_name in hu_model_map: @@ -838,44 +836,7 @@ def _fetch_model(model_name) -> str: # for all other cases (not local file or special download location), use HF model hub else: - get_from_model_hub = True - - # if not a local file, get from model hub - if get_from_model_hub: - hf_model_name = "pytorch_model.bin" - revision = "main" - - if "@" in model_name: - model_name_split = model_name.split("@") - revision = model_name_split[-1] - model_name = model_name_split[0] - - # use model name as subfolder - model_folder = model_name.split("/", maxsplit=1)[1] if "/" in model_name else model_name - - # Lazy import - from huggingface_hub.file_download import hf_hub_download - - try: - model_path = hf_hub_download( - repo_id=model_name, - filename=hf_model_name, - revision=revision, - library_name="flair", - library_version=flair.__version__, - cache_dir=flair.cache_root / "models" / model_folder, - ) - except HTTPError: - # output information - log.error("-" * 80) - log.error( - f"ERROR: The key '{model_name}' was neither found on the ModelHub nor is this a valid path to a file on your system!" - ) - log.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.") - log.error(" -> Alternatively, point to a model file on your local drive.") - log.error("-" * 80) - Path(flair.cache_root / "models" / model_folder).rmdir() # remove folder again if not valid - raise + model_path = hf_download(model_name) return model_path diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index b4b4475ec6..7a165041d1 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -1,5 +1,5 @@ from flair.data import Sentence -from flair.models.biomedical_entity_linking import ( +from flair.models.entity_mention_linking import ( EntityMentionLinker, load_dictionary, ) @@ -51,11 +51,11 @@ def test_biomedical_entity_linking(): tagger = Classifier.load("hunflair") tagger.predict(sentence) - disease_linker = EntityMentionLinker.build("diseases", "diseases-nel", hybrid_search=True) + disease_linker = EntityMentionLinker.load("bio-disease") disease_dictionary = disease_linker.dictionary disease_linker.predict(sentence) - gene_linker = EntityMentionLinker.build("genes", "genes-nel", hybrid_search=False, entity_type="genes") + gene_linker = EntityMentionLinker.load("bio-genes") gene_dictionary = gene_linker.dictionary gene_linker.predict(sentence) @@ -73,5 +73,3 @@ def test_biomedical_entity_linking(): for candidate_label in span.get_labels(gene_linker.label_type): candidate = gene_dictionary[candidate_label.value] print(f"Candidate: {candidate.concept_name}") - - breakpoint() # noqa: T100