diff --git a/lib/model/generic_transformer.py b/lib/model/generic_transformer.py index 16430db..193365a 100644 --- a/lib/model/generic_transformer.py +++ b/lib/model/generic_transformer.py @@ -12,9 +12,10 @@ def __init__(self, model_name: str): Load specified model name from subclass constant as HuggingFace transformer. """ self.model = None - self.model_name = model_name - if model_name: - self.model = SentenceTransformer(model_name, cache_folder=os.getenv("MODEL_DIR", "./models")) + self.model_name = os.environ.get("MODEL_NAME") + self.internal_model_name = model_name + if self.internal_model_name: + self.model = SentenceTransformer(self.internal_model_name, cache_folder=os.getenv("MODEL_DIR", "./models")) def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.GenericItem]: """