simplifying HF Chat Model by making use of HF Chat Templates
eryk-dsai committed Oct 10, 2023
1 parent 7756700 commit d9e9ef3
Showing 3 changed files with 45 additions and 86 deletions.
22 changes: 9 additions & 13 deletions docs/extras/integrations/chat/huggingface_pipeline.ipynb
@@ -11,21 +11,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook demonstrates the use of Hugging Face models as LangChain Chat models. We support Llama 2 Chat models out of the box because the prompt templates for instruction-tuned models differ from model to model. To handle any other Hugging Face model, simply create a class that inherits from the `ChatHuggingFacePipeline` class and implement a custom `format_messages_as_text` that parses the List of Messages to string."
"This notebook demonstrates how to use Hugging Face models as LangChain Chat models, using the Llama 2 Chat model as an example. We use the Hugging Face tokenizer's 'apply_chat_template' method to handle different instruction tuned models with different prompting templates. If you want to change the prompt templateing behavior, you can find instructions in the Hugging Face [guide](https://huggingface.co/docs/transformers/main/en/chat_templating)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Llama-2-Chat"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup"
"## Setup"
]
},
{
@@ -44,7 +37,7 @@
")\n",
"\n",
"# LangChain imports:\n",
"from langchain.chat_models import ChatHFLlama2Pipeline\n",
"from langchain.chat_models import ChatHuggingFacePipeline\n",
"from langchain.schema import AIMessage, HumanMessage, SystemMessage"
]
},
@@ -100,7 +93,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating Hugging Face Pipeline instance:"
"## Creating Hugging Face Pipeline instance:"
]
},
{
@@ -176,6 +169,9 @@
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"# disabling the default System Message of the Llama model \n",
"tokenizer.use_default_system_prompt = False\n",
"\n",
"model_4bit = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map=\"auto\")"
]
},
@@ -198,7 +194,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Initializing a LangChain Llama-2-Chat instance"
"Initializing the Chat Model instance"
]
},
{
@@ -207,7 +203,7 @@
"metadata": {},
"outputs": [],
"source": [
"chat = ChatHFLlama2Pipeline(pipeline=pipe)"
"chat = ChatHuggingFacePipeline(pipeline=pipe)"
]
},
{
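Taken together, the updated notebook boils down to the flow below. This is a minimal sketch rather than the notebook verbatim: the checkpoint name, the plain `pipeline(...)` call, and `max_new_tokens` are assumptions, and the notebook's 4-bit `bnb_config` quantization setup is elided.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from langchain.chat_models import ChatHuggingFacePipeline
from langchain.schema import HumanMessage, SystemMessage

model_name = "meta-llama/Llama-2-7b-chat-hf"  # assumed checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name)
# disabling the default System Message of the Llama model, as in the notebook
tokenizer.use_default_system_prompt = False
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# the chat model's validator requires a "text-generation" pipeline
pipe = pipeline(
    "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256
)

chat = ChatHuggingFacePipeline(pipeline=pipe)
messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What are chat templates?"),
]
print(chat(messages).content)
```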
4 changes: 2 additions & 2 deletions libs/langchain/langchain/chat_models/__init__.py
@@ -27,7 +27,7 @@
from langchain.chat_models.fake import FakeListChatModel
from langchain.chat_models.fireworks import ChatFireworks
from langchain.chat_models.google_palm import ChatGooglePalm
from langchain.chat_models.huggingface_pipeline import ChatHFLlama2Pipeline
from langchain.chat_models.huggingface_pipeline import ChatHuggingFacePipeline
from langchain.chat_models.human import HumanInputChatModel
from langchain.chat_models.javelin_ai_gateway import ChatJavelinAIGateway
from langchain.chat_models.jinachat import JinaChat
@@ -51,7 +51,7 @@
"ChatGooglePalm",
"ChatMLflowAIGateway",
"ChatOllama",
"ChatHFLlama2Pipeline",
"ChatHuggingFacePipeline",
"ChatVertexAI",
"JinaChat",
"HumanInputChatModel",
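For downstream users, only the class name in the public import path changes, e.g.:

```python
from langchain.chat_models import ChatHuggingFacePipeline  # formerly ChatHFLlama2Pipeline
```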
105 changes: 34 additions & 71 deletions libs/langchain/langchain/chat_models/huggingface_pipeline.py
@@ -1,6 +1,4 @@
import importlib.util
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from langchain.callbacks.manager import (
@@ -18,18 +16,38 @@
from langchain.schema.output import ChatGeneration


class ChatHuggingFacePipeline(BaseChatModel, ABC):
class ChatHuggingFacePipeline(BaseChatModel):
pipeline: Any

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "huggingface_pipeline_chat"

@abstractmethod
def format_messages_as_text(self, messages: List[BaseMessage]) -> str:
"""Method for parsing the list of LangChain Messages into string"""
...
@staticmethod
def convert_lc_messages_to_hf_messages(
messages: List[BaseMessage],
) -> List[Dict[str, str]]:
"""
Method for converting the list of LangChain Messages into
format required by Hugging Face.
"""
output = []

for message in messages:
if isinstance(message, SystemMessage):
output.append({"role": "system", "content": message.content})
elif isinstance(message, HumanMessage):
output.append({"role": "user", "content": message.content})
elif isinstance(message, AIMessage):
output.append({"role": "assistant", "content": message.content})
else:
raise ValueError(
f"Unexpected message type: {type(message)}. "
"Expected one of [SystemMessage, HumanMessage, AIMessage]."
)

return output

@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
@@ -39,6 +57,13 @@ def validate_environment(cls, values: Dict) -> Dict:
):
raise ValueError("The pipeline task should be 'text-generation'.")

        if not hasattr(values["pipeline"].tokenizer, "apply_chat_template"):
            raise ValueError(
                "Your transformers module might be outdated. "
                "Please update it to ensure that the tokenizer has the "
                "'apply_chat_template' method."
            )

return values

def _generate(
@@ -48,7 +73,8 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
prompt = self.format_messages_as_text(messages)
chat = self.convert_lc_messages_to_hf_messages(messages)
prompt = self.pipeline.tokenizer.apply_chat_template(chat, tokenize=False)

# make sure that `return_full_text` is set to False
# otherwise, pipeline will return prompt + generation
@@ -134,66 +160,3 @@ def __call__(
message=AIMessage(content=response),
)
return ChatResult(generations=[chat_generation])


class ChatHFLlama2Pipeline(ChatHuggingFacePipeline):
class InstructionTokens(Enum):
def __str__(self) -> str:
return self.value

B_INST = "[INST]"
E_INST = "[/INST]"

class SystemTokens(Enum):
def __str__(self) -> str:
return self.value

B_SYS = "<<SYS>>"
E_SYS = "<</SYS>>"

def format_messages_as_text(self, messages: List[BaseMessage]) -> str:
"""
Transform List of Chat Messages to text following Meta's prompt guidelines.
Prompt template with System Message:
```
<s>[INST] <<SYS>>
{{ system_prompt }}
<</SYS>>
{{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s>
```
Prompt template without System Message:
```
<s>[INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s>
```
Source:
https://github.com/facebookresearch/llama-recipes/blob/df77625e48c3994aef19702fb331215f7fb83494/docs/inference.md?plain=1#L124
"""
prompt = ""

for i, message in enumerate(messages):
if isinstance(message, SystemMessage) and i != 0:
raise ValueError(
"SystemMessage can only appear as the first message in the list."
)
elif isinstance(message, SystemMessage) and i == 0:
prompt += (
f"<s>{self.InstructionTokens.B_INST} "
f"{self.SystemTokens.B_SYS}\n{message.content}\n"
f"{self.SystemTokens.E_SYS}\n\n"
)
elif isinstance(message, HumanMessage) and i > 0:
prompt += f"{message.content} {self.InstructionTokens.E_INST} "
elif isinstance(message, HumanMessage) and i == 0:
prompt += (
f"<s>{self.InstructionTokens.B_INST} "
f"{message.content} {self.InstructionTokens.E_INST} "
)
elif isinstance(message, AIMessage):
prompt += f"{message.content} </s><s>{self.InstructionTokens.B_INST} "
else:
raise ValueError(f"Unsupported Message type: {type(message)}")

return prompt
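For comparison with the hand-rolled formatter removed above, here is a sketch of what the new two-step path in `_generate` produces. The checkpoint name is an assumption, and the exact prompt string depends on that checkpoint's chat template:

```python
from transformers import AutoTokenizer

from langchain.chat_models import ChatHuggingFacePipeline
from langchain.schema import HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Hi!"),
]

# step 1: LangChain messages -> Hugging Face role/content dicts
hf_messages = ChatHuggingFacePipeline.convert_lc_messages_to_hf_messages(messages)
# [{'role': 'system', 'content': 'You are a helpful assistant.'},
#  {'role': 'user', 'content': 'Hi!'}]

# step 2: the tokenizer's chat template renders the prompt string
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # assumed
prompt = tokenizer.apply_chat_template(hf_messages, tokenize=False)
# for Llama 2 this matches the removed formatter's output, e.g.
# "<s>[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\nHi! [/INST]"
print(prompt)
```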
