diff --git a/blendsql/models/local/_transformers.py b/blendsql/models/local/_transformers.py index 97e4b941..aeb5528c 100644 --- a/blendsql/models/local/_transformers.py +++ b/blendsql/models/local/_transformers.py @@ -1,6 +1,8 @@ import importlib.util from typing import Optional +from colorama import Fore +from ..._logger import logger from .._model import LocalModel, ModelObj DEFAULT_KWARGS = {"do_sample": True, "temperature": 0.0, "top_p": 1.0} @@ -47,6 +49,7 @@ def __init__( transformers.logging.set_verbosity_error() if config is None: config = {} + super().__init__( model_name_or_path=model_name_or_path, requires_config=False, @@ -61,12 +64,20 @@ def _load_model(self) -> ModelObj: from guidance.models import Transformers import torch - return Transformers( + lm = Transformers( self.model_name_or_path, echo=False, device_map="cuda" if torch.cuda.is_available() else "cpu", **self.load_model_kwargs, ) + # Try to infer if we're in chat mode + if lm.engine.tokenizer._orig_tokenizer.chat_template is None: + logger.debug( + Fore.YELLOW + + "chat_template not found in tokenizer config.\nBlendSQL currently only works with chat models" + + Fore.RESET + ) + return lm class TransformersVisionModel(TransformersLLM):