Skip to content

Commit

Permalink
fix(summarize): Use a better algorithm for chunked summaries
Browse files Browse the repository at this point in the history
Signed-off-by: Marcel Klehr <[email protected]>
  • Loading branch information
marcelklehr committed Dec 12, 2024
1 parent f138ce8 commit c873355
Showing 1 changed file with 29 additions and 56 deletions.
85 changes: 29 additions & 56 deletions lib/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,20 @@
"""A recursive summarize chain
"""

from typing import Any, Optional
from typing import Any

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.prompts import PromptTemplate
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import LLMChain
from langchain_core.runnables import Runnable


class SummarizeChain(Chain):
class SummarizeProcessor:
"""
A summarization chain
"""

system_prompt = "You're an AI assistant tasked with summarizing the text given to you by the user."
system_prompt: str = "You're an AI assistant tasked with summarizing the text given to you by the user."
user_prompt: BasePromptTemplate = PromptTemplate(
input_variables=["text"],
template="""
Expand All @@ -30,63 +27,39 @@ class SummarizeChain(Chain):
"
Output only the summary without quotes, nothing else, especially no introductory or explanatory text. Also, do not mention the language you used explicitly.
Here is your summary in the same language as the text:
Here is your summary in the same language as the original text:
"""
)


llm_chain: LLMChain
runnable: Runnable
n_ctx: int = 8000
output_key: str = "text" #: :meta private:

class Config:
"""Configuration for this pydantic object."""
def __init__(self, runnable: Runnable, n_ctx: int = 8000):
    """Create a summarizer around *runnable*.

    :param runnable: the LLM runnable used to produce each chunk summary
    :param n_ctx: context window size of the underlying model, in tokens
    """
    self.n_ctx = n_ctx
    self.runnable = runnable

extra = 'forbid'
arbitrary_types_allowed = True

@property
def input_keys(self) -> list[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return ['input']

@property
def output_keys(self) -> list[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]

def _call(
def __call__(
    self,
    inputs: dict[str, Any],
) -> dict[str, Any]:
    """Recursively summarize ``inputs['input']``.

    The text is split into chunks sized to fit the model's context window,
    each chunk is summarized by the runnable, and the joined summaries are
    re-split and re-summarized until a pass no longer reduces the number of
    chunks.

    :param inputs: dict with the source text under the ``'input'`` key
    :return: dict with the final summary under the ``'output'`` key
    """
    # Reserve ~30% of the context window for the generated summary, but
    # never shrink chunks below a workable minimum. Cast to int: max() on a
    # float operand yields a float, and the splitter expects an int size.
    chunk_size = max(int(self.n_ctx * 0.7), 2048)

    # chunk_size is in tokens; the splitter counts characters, so allow
    # roughly 4 characters per token.
    text_splitter = CharacterTextSplitter(
        separator='\n\n|\\.|\\?|\\!', is_separator_regex=True,
        chunk_size=chunk_size * 4, chunk_overlap=0, keep_separator=True)

    def _summarize_pass(parts: list[str]) -> list[str]:
        # One summarization pass: each chunk becomes its own summary.
        return [
            self.runnable.invoke({
                "user_prompt": self.user_prompt.format_prompt(text=part),
                "system_prompt": self.system_prompt,
            })
            for part in parts
        ]

    chunks = text_splitter.split_text(inputs['input'])

    # Condense repeatedly; stop as soon as a pass fails to shrink the
    # chunk count, which means the text now fits (or cannot be reduced).
    old_num_chunks = len(chunks)
    summaries = _summarize_pass(chunks)
    chunks = text_splitter.split_text('\n\n'.join(summaries))
    while len(chunks) < old_num_chunks:
        old_num_chunks = len(chunks)
        summaries = _summarize_pass(chunks)
        chunks = text_splitter.split_text('\n\n'.join(summaries))

    return {'output': '\n\n'.join(summaries)}

0 comments on commit c873355

Please sign in to comment.