Skip to content

Commit

Permalink
Model inside image
Browse files Browse the repository at this point in the history
  • Loading branch information
phlobo committed Dec 5, 2024
1 parent db9f9a3 commit 3d5011a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
5 changes: 4 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ RUN conda install -n xmen -c conda-forge nmslib cymem murmurhash -y
# Install pip dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Download the pre-trained models so they are cached in the Docker image
RUN python download_models.py

EXPOSE 5000

# Define the command to run the server with parameters
CMD ["conda", "run", "-n", "xmen", "python3", "run_snomed_german_recommender.py", "--no-gpu", "--port", "5000", "index"]
CMD ["conda", "run", "-n", "xmen", "python3", "run_snomed_german_recommender.py", "--no-gpu", "--port", "5000", "index", "--num_recs", "10"]
15 changes: 15 additions & 0 deletions download_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from xmen.linkers.model_wrapper import Model_Wrapper
from xmen.linkers import SapBERTLinker
from transformers import logging as tf_logging
import logging

logging.basicConfig(level=logging.INFO)
tf_logging.set_verbosity_info()

def download_models():
""" Downloads the Hugging Face models required for the project. """
Model_Wrapper().load_model(SapBERTLinker.CROSS_LINGUAL, use_cuda=False)

if __name__ == '__main__':

download_models()
8 changes: 8 additions & 0 deletions run_snomed_german_recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@

from utils import handle_dates

from transformers import logging as tf_logging
import logging

logging.basicConfig(level=logging.INFO)
tf_logging.set_verbosity_info()

class xMENSNOMEDLinker(Classifier):
def __init__(self, linker: EntityLinker, top_k = 3):
self.linker = linker
Expand Down Expand Up @@ -48,8 +54,10 @@ def run():

# Suppress InconsistentVersionWarning from TF-IDF vectorizer
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
print("Loading xMEN SNOMED Linker...", flush=True)
linker = default_ensemble(args.index_base_path, cuda=args.gpu)

print("Starting Ariadne server...", flush=True)
server = Server()
server.add_classifier("xmen_snomed", xMENSNOMEDLinker(linker, top_k=args.num_recs))

Expand Down

0 comments on commit 3d5011a

Please sign in to comment.