Merge pull request #693 from dimagi/fr/chat-history
Pipeline Chat History
proteusvacuum authored Oct 10, 2024
2 parents 35fb851 + b6d8abc commit b38dc02
Showing 13 changed files with 811 additions and 33 deletions.
4 changes: 3 additions & 1 deletion apps/chat/bots.py
@@ -283,5 +283,7 @@ def __init__(self, session: ExperimentSession):
         self.session = session

     def process_input(self, user_input: str, save_input_to_history=True, attachments: list["Attachment"] | None = None):
-        output = self.experiment.pipeline.invoke(PipelineState(messages=[user_input]), self.session)
+        output = self.experiment.pipeline.invoke(
+            PipelineState(messages=[user_input], experiment_session=self.session), self.session
+        )
        return output["messages"][-1]
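
Note: the key change here is that the experiment session is now threaded through PipelineState so downstream nodes can read and write chat history. A rough sketch of the assumed state shape (the canonical definition lives in apps/pipelines/nodes/base.py and also provides json_safe(); the fields below are inferred from this diff only):

from typing import TypedDict

class PipelineState(TypedDict, total=False):
    # Sketch only; inferred from how the diff constructs and indexes the state.
    messages: list              # running outputs; messages[-1] is the latest turn
    outputs: dict               # per-node outputs keyed by node_id (see base.py below)
    experiment_session: object  # the ExperimentSession, now available to every node
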
47 changes: 47 additions & 0 deletions apps/pipelines/migrations/… (new file)
@@ -0,0 +1,47 @@
+# Generated by Django 5.1 on 2024-10-10 12:26
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('experiments', '0095_experiment_debug_mode_enabled'),
+        ('pipelines', '0005_auto_20240802_0039'),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name='PipelineChatHistory',
+            fields=[
+                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+                ('created_at', models.DateTimeField(auto_now_add=True)),
+                ('updated_at', models.DateTimeField(auto_now=True)),
+                ('type', models.CharField(choices=[('node', 'Node History'), ('named', 'Named History'), ('global', 'Global History'), ('none', 'No History')], max_length=10)),
+                ('name', models.CharField(db_index=True, max_length=128)),
+                ('session', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='pipeline_chat_history', to='experiments.experimentsession')),
+            ],
+            options={
+                'ordering': ['-created_at'],
+            },
+        ),
+        migrations.CreateModel(
+            name='PipelineChatMessages',
+            fields=[
+                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+                ('created_at', models.DateTimeField(auto_now_add=True)),
+                ('updated_at', models.DateTimeField(auto_now=True)),
+                ('human_message', models.TextField()),
+                ('ai_message', models.TextField()),
+                ('chat_history', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='messages', to='pipelines.pipelinechathistory')),
+            ],
+            options={
+                'abstract': False,
+            },
+        ),
+        migrations.AddConstraint(
+            model_name='pipelinechathistory',
+            constraint=models.UniqueConstraint(fields=('session', 'type', 'name'), name='unique_session_type_name'),
+        ),
+    ]
33 changes: 30 additions & 3 deletions apps/pipelines/models.py
@@ -98,15 +98,13 @@ def invoke(
         pipeline_run = self._create_pipeline_run(input, session)
         logging_callback = PipelineLoggingCallbackHandler(pipeline_run)

-        if save_run_to_history and session is not None:
-            self._save_message_to_history(session, input["messages"][-1], ChatMessageType.HUMAN)
-
         logging_callback.logger.debug("Starting pipeline run", input=input["messages"][-1])
         try:
             output = runnable.invoke(input, config=RunnableConfig(callbacks=[logging_callback]))
             output = PipelineState(**output).json_safe()
             pipeline_run.output = output
             if save_run_to_history and session is not None:
+                self._save_message_to_history(session, input["messages"][-1], ChatMessageType.HUMAN)
                 self._save_message_to_history(session, output["messages"][-1], ChatMessageType.AI)
         finally:
             if pipeline_run.status == PipelineRunStatus.ERROR:
@@ -182,3 +180,32 @@ class PipelineEventInputs(models.TextChoices):
     FULL_HISTORY = "full_history", "Full History"
     HISTORY_LAST_SUMMARY = "history_last_summary", "History to last summary"
     LAST_MESSAGE = "last_message", "Last message"
+
+
+class PipelineChatHistoryTypes(models.TextChoices):
+    NODE = "node", "Node History"
+    NAMED = "named", "Named History"
+    GLOBAL = "global", "Global History"
+    NONE = "none", "No History"
+
+
+class PipelineChatHistory(BaseModel):
+    session = models.ForeignKey(ExperimentSession, on_delete=models.CASCADE, related_name="pipeline_chat_history")
+
+    type = models.CharField(max_length=10, choices=PipelineChatHistoryTypes.choices)
+    name = models.CharField(max_length=128, db_index=True)  # Either the name of the named history, or the node id
+
+    class Meta:
+        constraints = [
+            models.UniqueConstraint(fields=("session", "type", "name"), name="unique_session_type_name"),
+        ]
+        ordering = ["-created_at"]
+
+
+class PipelineChatMessages(BaseModel):
+    chat_history = models.ForeignKey(PipelineChatHistory, on_delete=models.CASCADE, related_name="messages")
+    human_message = models.TextField()
+    ai_message = models.TextField()
+
+    def as_tuples(self):
+        return [(ChatMessageType.HUMAN.value, self.human_message), (ChatMessageType.AI.value, self.ai_message)]
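
Note: these two models give each node (or named scope) in a session its own message log, with the unique constraint guaranteeing at most one history per (session, type, name). A hypothetical usage sketch, mirroring what the LLM node does further down ("node-1" stands in for a real node id; session is an ExperimentSession):

history, _ = session.pipeline_chat_history.get_or_create(
    type=PipelineChatHistoryTypes.NODE, name="node-1"
)
history.messages.create(human_message="Hi there", ai_message="Hello! How can I help?")

# as_tuples() flattens each stored pair for prompt assembly, assuming
# ChatMessageType.HUMAN/.AI serialize to "human"/"ai":
flat = [msg for pair in history.messages.all() for msg in pair.as_tuples()]
# -> [("human", "Hi there"), ("ai", "Hello! How can I help?")]
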
4 changes: 2 additions & 2 deletions apps/pipelines/nodes/base.py
@@ -79,15 +79,15 @@ def process(self, node_id: str, incoming_edges: list, state: PipelineState, conf
                 break
         else:  # This is the first node in the graph
             input = state["messages"][-1]
-        output = self._process(input, state)
+        output = self._process(input=input, state=state, node_id=node_id)
         # Append the output to the state, otherwise do not change the state
         return (
             PipelineState(messages=[output], outputs={node_id: output})
             if output
             else PipelineState(outputs={node_id: output})
         )

-    def _process(self, input: str, state: PipelineState) -> PipelineState:
+    def _process(self, input: str, state: PipelineState, node_id: str) -> PipelineState:
         """The method that executes node specific functionality"""
         raise NotImplementedError
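
Note: under the new contract every node receives its node_id, which the history code uses as the default history name. An illustrative sketch of a custom node, not one from this PR (despite the annotation, _process implementations in this codebase return the node's raw output, which process() folds back into PipelineState):

class Shout(PipelineNode):
    __human_name__ = "Shout"

    def _process(self, input: str, state: PipelineState, node_id: str) -> PipelineState:
        # Nodes that need neither state nor node_id can absorb them with
        # **kwargs instead, as several nodes in nodes.py now do.
        return input.upper()
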
86 changes: 69 additions & 17 deletions apps/pipelines/nodes/nodes.py
@@ -5,21 +5,27 @@
 from django.utils import timezone
 from jinja2 import meta
 from jinja2.sandbox import SandboxedEnvironment
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.messages import BaseMessage
 from langchain_core.prompts import PromptTemplate
 from langchain_core.runnables import RunnableLambda, RunnablePassthrough
 from langchain_text_splitters import RecursiveCharacterTextSplitter
-from pydantic import Field, create_model
+from pydantic import BaseModel, Field, create_model

 from apps.channels.models import ChannelPlatform
-from apps.experiments.models import ParticipantData, SourceMaterial
+from apps.chat.conversation import compress_chat_history
+from apps.experiments.models import ExperimentSession, ParticipantData, SourceMaterial
 from apps.pipelines.exceptions import PipelineNodeBuildError
+from apps.pipelines.models import PipelineChatHistory, PipelineChatHistoryTypes
 from apps.pipelines.nodes.base import PipelineNode, PipelineState
 from apps.pipelines.nodes.types import (
+    HistoryName,
+    HistoryType,
     Keywords,
     LlmModel,
     LlmProviderId,
     LlmTemperature,
+    MaxTokenLimit,
     NumOutputs,
     PipelineJinjaTemplate,
     Prompt,
@@ -34,7 +40,7 @@ class RenderTemplate(PipelineNode):
     __human_name__ = "Render a template"
     template_string: PipelineJinjaTemplate

-    def _process(self, input, state: PipelineState) -> PipelineState:
+    def _process(self, input, **kwargs) -> PipelineState:
         def all_variables(in_):
             return {var: in_ for var in meta.find_undeclared_variables(env.parse(self.template_string))}

@@ -56,10 +62,13 @@ def all_variables(in_):
         return template.render(content)


-class LLMResponseMixin:
+class LLMResponseMixin(BaseModel):
     llm_provider_id: LlmProviderId
     llm_model: LlmModel
     llm_temperature: LlmTemperature = 1.0
+    history_type: HistoryType = PipelineChatHistoryTypes.NONE
+    history_name: HistoryName | None = None
+    max_token_limit: MaxTokenLimit = 8192

     def get_llm_service(self):
         from apps.service_providers.models import LlmProvider
@@ -79,7 +88,7 @@ def get_chat_model(self):
 class LLMResponse(PipelineNode, LLMResponseMixin):
     __human_name__ = "LLM response"

-    def _process(self, input, state: PipelineState) -> PipelineState:
+    def _process(self, input, **kwargs) -> PipelineState:
         llm = self.get_chat_model()
         output = llm.invoke(input, config=self._config)
         return output.content
@@ -91,19 +100,22 @@ class LLMResponseWithPrompt(LLMResponse):
     source_material_id: SourceMaterialId | None = None
     prompt: Prompt = "You are a helpful assistant. Answer the user's query as best you can: {input}"

-    def _process(self, input, state: PipelineState) -> PipelineState:
-        prompt = PromptTemplate.from_template(template=self.prompt)
-        context = self._get_context(input, state, prompt)
+    def _process(self, input, state: PipelineState, node_id: str) -> PipelineState:
+        prompt = ChatPromptTemplate.from_messages(
+            [("system", self.prompt), MessagesPlaceholder("history", optional=True), ("human", "{input}")]
+        )
+        context = self._get_context(input, state, prompt, node_id)
+        if self.history_type != PipelineChatHistoryTypes.NONE:
+            input_messages = prompt.invoke(context).to_messages()
+            context["history"] = self._get_history(state["experiment_session"], node_id, input_messages)
         chain = prompt | super().get_chat_model()
         output = chain.invoke(context, config=self._config)
+        self._save_history(state["experiment_session"], node_id, input, output.content)
         return output.content

-    def _get_context(self, input, state: PipelineState, prompt: PromptTemplate):
-        session = state["experiment_session"]
-        context = {}
-
-        if "input" in prompt.input_variables:
-            context["input"] = input
+    def _get_context(self, input, state: PipelineState, prompt: ChatPromptTemplate, node_id: str):
+        session: ExperimentSession = state["experiment_session"]
+        context = {"input": input}

         if "source_material" in prompt.input_variables and self.source_material_id is None:
             raise PipelineNodeBuildError("No source material set, but the prompt expects it")
@@ -118,6 +130,46 @@ def _get_context(self, input, state: PipelineState, prompt: PromptTemplate):

         return context

+    def _get_history_name(self, node_id):
+        if self.history_type == PipelineChatHistoryTypes.NAMED:
+            return self.history_name
+        return node_id
+
+    def _get_history(self, session: ExperimentSession, node_id: str, input_messages: list) -> list:
+        if self.history_type == PipelineChatHistoryTypes.NONE:
+            return []
+
+        if self.history_type == PipelineChatHistoryTypes.GLOBAL:
+            return compress_chat_history(
+                chat=session.chat,
+                llm=self.get_chat_model(),
+                max_token_limit=self.max_token_limit,
+                input_messages=input_messages,
+            )
+
+        try:
+            history: PipelineChatHistory = session.pipeline_chat_history.get(
+                type=self.history_type, name=self._get_history_name(node_id)
+            )
+        except PipelineChatHistory.DoesNotExist:
+            return []
+        message_pairs = history.messages.all()
+        return [message for message_pair in message_pairs for message in message_pair.as_tuples()]
+
+    def _save_history(self, session: ExperimentSession, node_id: str, human_message: str, ai_message: str):
+        if self.history_type == PipelineChatHistoryTypes.NONE:
+            return
+
+        if self.history_type == PipelineChatHistoryTypes.GLOBAL:
+            # Global History is saved outside of the node
+            return
+
+        history, _ = session.pipeline_chat_history.get_or_create(
+            type=self.history_type, name=self._get_history_name(node_id)
+        )
+        message = history.messages.create(human_message=human_message, ai_message=ai_message)
+        return message
+
     def _get_participant_data(self, session):
         if session.experiment_channel.platform == ChannelPlatform.WEB and session.participant.user is None:
             return ""
@@ -138,7 +190,7 @@ class SendEmail(PipelineNode):
     recipient_list: str
     subject: str

-    def _process(self, input, state: PipelineState) -> PipelineState:
+    def _process(self, input, **kwargs) -> PipelineState:
         send_email_from_pipeline.delay(
             recipient_list=self.recipient_list.split(","), subject=self.subject, message=input
         )
@@ -147,7 +199,7 @@ def _process(self, input, state: PipelineState) -> PipelineState:
 class Passthrough(PipelineNode):
     __human_name__ = "Do Nothing"

-    def _process(self, input, state: PipelineState) -> PipelineState:
+    def _process(self, input, state: PipelineState, node_id: str) -> PipelineState:
         self.logger.debug(f"Returning input: '{input}' without modification", input=input, output=input)
         return input

@@ -211,7 +263,7 @@ def _prompt_chain(self, reference_data):
     def extraction_chain(self, json_schema, reference_data):
         return self._prompt_chain(reference_data) | super().get_chat_model().with_structured_output(json_schema)

-    def _process(self, input, state: PipelineState) -> RunnableLambda:
+    def _process(self, input, state: PipelineState, **kwargs) -> PipelineState:
         json_schema = self.to_json_schema(json.loads(self.data_schema))
         reference_data = self.get_reference_data(state)
         prompt_token_count = self._get_prompt_token_count(reference_data, json_schema)
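
Note: to see how the stored history feeds back into the prompt, recall that the node builds a ChatPromptTemplate with an optional "history" placeholder, and _get_history supplies (role, content) tuples that LangChain converts into messages. A standalone sketch with made-up messages (not code from this PR):

from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

prompt = ChatPromptTemplate.from_messages(
    [("system", "You are a helpful assistant."), MessagesPlaceholder("history", optional=True), ("human", "{input}")]
)
# Tuples shaped like PipelineChatMessages.as_tuples() output:
messages = prompt.invoke(
    {"input": "What did I just ask?", "history": [("human", "Hi"), ("ai", "Hello!")]}
).to_messages()
# -> [SystemMessage, HumanMessage("Hi"), AIMessage("Hello!"), HumanMessage("What did I just ask?")]
# Because the placeholder is optional, omitting "history" (the NONE case) still renders.
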
5 changes: 4 additions & 1 deletion apps/pipelines/nodes/types.py
@@ -1,9 +1,12 @@
 from typing_extensions import TypeAliasType

+HistoryName = TypeAliasType("HistoryName", str)
+HistoryType = TypeAliasType("HistoryType", str)
 Keywords = TypeAliasType("Keywords", list)
-LlmProviderId = TypeAliasType("LlmProviderId", int)
 LlmModel = TypeAliasType("LlmModel", str)
+LlmProviderId = TypeAliasType("LlmProviderId", int)
 LlmTemperature = TypeAliasType("LlmTemperature", float)
+MaxTokenLimit = TypeAliasType("MaxTokenLimit", int)
 NumOutputs = TypeAliasType("NumOutputs", int)
 PipelineJinjaTemplate = TypeAliasType("PipelineJinjaTemplate", str)
 Prompt = TypeAliasType("Prompt", str)