Skip to content

Commit

Permalink
add hf model download
Browse files Browse the repository at this point in the history
  • Loading branch information
helpmefindaname committed Dec 24, 2023
1 parent d808e8a commit 2fae0a2
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 49 deletions.
38 changes: 38 additions & 0 deletions flair/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import torch
from botocore import UNSIGNED
from botocore.config import Config
from requests import HTTPError
from tqdm import tqdm as _tqdm

import flair
Expand Down Expand Up @@ -143,6 +144,43 @@ def unzip_file(file: Union[str, Path], unzip_to: Union[str, Path]):
zipObj.extractall(Path(unzip_to))


def hf_download(model_name: str) -> str:
hf_model_name = "pytorch_model.bin"
revision = "main"

if "@" in model_name:
model_name_split = model_name.split("@")
revision = model_name_split[-1]
model_name = model_name_split[0]

# use model name as subfolder
model_folder = model_name.split("/", maxsplit=1)[1] if "/" in model_name else model_name

# Lazy import
from huggingface_hub.file_download import hf_hub_download

try:
return hf_hub_download(
repo_id=model_name,
filename=hf_model_name,
revision=revision,
library_name="flair",
library_version=flair.__version__,
cache_dir=flair.cache_root / "models" / model_folder,
)
except HTTPError:
# output information
logger.error("-" * 80)
logger.error(
f"ERROR: The key '{model_name}' was neither found on the ModelHub nor is this a valid path to a file on your system!"
)
logger.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.")
logger.error(" -> Alternatively, point to a model file on your local drive.")
logger.error("-" * 80)
Path(flair.cache_root / "models" / model_folder).rmdir() # remove folder again if not valid
raise


def unpack_file(file: Path, unpack_to: Path, mode: Optional[str] = None, keep: bool = True):
"""Unpacks an archive file to the given location.
Expand Down
21 changes: 19 additions & 2 deletions flair/models/entity_mention_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from flair.datasets.entity_linking import InMemoryEntityLinkingDictionary
from flair.embeddings import DocumentEmbeddings, DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.embeddings.base import load_embeddings
from flair.file_utils import cached_path
from flair.file_utils import cached_path, hf_download
from flair.training_utils import Result

logger = logging.getLogger("flair")
Expand Down Expand Up @@ -772,7 +772,24 @@ def _fetch_model(model_name: str) -> str:
if Path(model_name).exists():
return model_name

raise NotImplementedError
bio_base_repo = "helpmefindaname"

hf_model_map = {
"bio-gene": f"{bio_base_repo}/flair-eml-sapbert-bc2gn-gene",
"bio-disease": f"{bio_base_repo}/flair-eml-sapbert-bc5cdr-disease",
"bio-chemical": f"{bio_base_repo}/flair-eml-sapbert-bc5cdr-chemical",
"bio-species": f"{bio_base_repo}/flair-eml-species-exact-match",
"bio-gene-exact-match": f"{bio_base_repo}/flair-eml-gene-exact-match",
"bio-disease-exact-match": f"{bio_base_repo}/flair-eml-disease-exact-match",
"bio-chemical-exact-match": f"{bio_base_repo}/flair-eml-chemical-exact-match",
"bio-species-exact-match": f"{bio_base_repo}/flair-eml-species-exact-match",
}

if model_name in hf_model_map:
model_name = hf_model_map[model_name]

return hf_download(model_name)


@classmethod
def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs) -> "EntityMentionLinker":
Expand Down
45 changes: 3 additions & 42 deletions flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from flair.data import Dictionary, Label, Sentence, Span, get_spans_from_bio
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.embeddings import TokenEmbeddings
from flair.file_utils import cached_path, unzip_file
from flair.file_utils import cached_path, unzip_file, hf_download
from flair.models.sequence_tagger_utils.crf import CRF
from flair.models.sequence_tagger_utils.viterbi import ViterbiDecoder, ViterbiLoss
from flair.training_utils import store_embeddings
Expand Down Expand Up @@ -775,9 +775,7 @@ def _fetch_model(model_name) -> str:
# get mapped name
hf_model_name = huggingface_model_map[model_name]

# use mapped name instead
model_name = hf_model_name
get_from_model_hub = True
model_path = hf_download(hf_model_name)

# if not, check if model key is remapped to direct download location. If so, download model
elif model_name in hu_model_map:
Expand Down Expand Up @@ -838,44 +836,7 @@ def _fetch_model(model_name) -> str:

# for all other cases (not local file or special download location), use HF model hub
else:
get_from_model_hub = True

# if not a local file, get from model hub
if get_from_model_hub:
hf_model_name = "pytorch_model.bin"
revision = "main"

if "@" in model_name:
model_name_split = model_name.split("@")
revision = model_name_split[-1]
model_name = model_name_split[0]

# use model name as subfolder
model_folder = model_name.split("/", maxsplit=1)[1] if "/" in model_name else model_name

# Lazy import
from huggingface_hub.file_download import hf_hub_download

try:
model_path = hf_hub_download(
repo_id=model_name,
filename=hf_model_name,
revision=revision,
library_name="flair",
library_version=flair.__version__,
cache_dir=flair.cache_root / "models" / model_folder,
)
except HTTPError:
# output information
log.error("-" * 80)
log.error(
f"ERROR: The key '{model_name}' was neither found on the ModelHub nor is this a valid path to a file on your system!"
)
log.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.")
log.error(" -> Alternatively, point to a model file on your local drive.")
log.error("-" * 80)
Path(flair.cache_root / "models" / model_folder).rmdir() # remove folder again if not valid
raise
model_path = hf_download(model_name)

return model_path

Expand Down
8 changes: 3 additions & 5 deletions tests/test_biomedical_entity_linking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from flair.data import Sentence
from flair.models.biomedical_entity_linking import (
from flair.models.entity_mention_linking import (
EntityMentionLinker,
load_dictionary,
)
Expand Down Expand Up @@ -51,11 +51,11 @@ def test_biomedical_entity_linking():
tagger = Classifier.load("hunflair")
tagger.predict(sentence)

disease_linker = EntityMentionLinker.build("diseases", "diseases-nel", hybrid_search=True)
disease_linker = EntityMentionLinker.load("bio-disease")
disease_dictionary = disease_linker.dictionary
disease_linker.predict(sentence)

gene_linker = EntityMentionLinker.build("genes", "genes-nel", hybrid_search=False, entity_type="genes")
gene_linker = EntityMentionLinker.load("bio-genes")
gene_dictionary = gene_linker.dictionary

gene_linker.predict(sentence)
Expand All @@ -73,5 +73,3 @@ def test_biomedical_entity_linking():
for candidate_label in span.get_labels(gene_linker.label_type):
candidate = gene_dictionary[candidate_label.value]
print(f"Candidate: {candidate.concept_name}")

breakpoint() # noqa: T100

0 comments on commit 2fae0a2

Please sign in to comment.