From c8733559515bec250c518e7dadc76323a6a4e7a2 Mon Sep 17 00:00:00 2001
From: Marcel Klehr
Date: Thu, 12 Dec 2024 10:05:25 +0100
Subject: [PATCH] fix(summarize): Use a better algorithm for chunked summaries

Signed-off-by: Marcel Klehr
---
 lib/summarize.py | 96 ++++++++++++++++++++++-------------------------------
 1 file changed, 40 insertions(+), 56 deletions(-)

diff --git a/lib/summarize.py b/lib/summarize.py
index 7f2ca99..5bc8762 100644
--- a/lib/summarize.py
+++ b/lib/summarize.py
@@ -3,23 +3,20 @@
 """A recursive summarize chain
 """
 
-from typing import Any, Optional
+from typing import Any
 
-from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.manager import CallbackManagerForChainRun
-from langchain.chains.base import Chain
 from langchain.prompts import PromptTemplate
 from langchain.schema.prompt_template import BasePromptTemplate
 from langchain.text_splitter import CharacterTextSplitter
-from langchain.chains import LLMChain
+from langchain_core.runnables import Runnable
 
 
-class SummarizeChain(Chain):
+class SummarizeProcessor:
     """
     A summarization chain
     """
 
-    system_prompt = "You're an AI assistant tasked with summarizing the text given to you by the user."
+    system_prompt: str = "You're an AI assistant tasked with summarizing the text given to you by the user."
     user_prompt: BasePromptTemplate = PromptTemplate(
         input_variables=["text"],
         template="""
@@ -30,63 +27,50 @@ class SummarizeChain(Chain):
 "
 Output only the summary without quotes, nothing else, especially no introductory or explanatory text. Also, do not mention the language you used explicitly.
 
-Here is your summary in the same language as the text:
+Here is your summary in the same language as the original text:
 
 """
     )
 
 
-    llm_chain: LLMChain
+    runnable: Runnable
     n_ctx: int = 8000
-    output_key: str = "text"  #: :meta private:
 
-    class Config:
-        """Configuration for this pydantic object."""
+    def __init__(self, runnable: Runnable, n_ctx: int = 8000):
+        """
+        :param runnable: invoked once per chunk with the keys ``user_prompt`` and ``system_prompt``
+        :param n_ctx: size the chunk size is derived from (presumably the model's context size)
+        """
+        self.runnable = runnable
+        self.n_ctx = n_ctx
 
-        extra = 'forbid'
-        arbitrary_types_allowed = True
-
-    @property
-    def input_keys(self) -> list[str]:
-        """Will be whatever keys the prompt expects.
-
-        :meta private:
-        """
-        return ['input']
-
-    @property
-    def output_keys(self) -> list[str]:
-        """Will always return text key.
-
-        :meta private:
-        """
-        return [self.output_key]
-
-    def _call(
+    def __call__(
         self,
         inputs: dict[str, Any],
-        run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> dict[str, str]:
-
-        if not {"user_prompt", "system_prompt"} == set(self.llm_chain.input_keys):
-            raise ValueError("llm_chain must have input_keys ['user_prompt', 'system_prompt']")
-        if not self.llm_chain.output_keys == [self.output_key]:
-            raise ValueError(f"llm_chain must have output_keys [{self.output_key}]")
-
-        summary_size = max(len(inputs['input']) * 0.2, 1000) # 2000 chars summary per 10.000 chars original text
-        chunk_size = max(self.n_ctx - summary_size, 2048)
+    ) -> dict[str, Any]:
+        """
+        Summarize ``inputs['input']`` chunk by chunk, then summarize the joined
+        summaries again for as long as that reduces the number of chunks.
+
+        :param inputs: dict holding the text to summarize under the key ``input``
+        :return: dict holding the last round of summaries, joined by blank lines,
+            under the key ``output``
+        """
+        # 70% of n_ctx (at least 2048), times 4 -- presumably to turn a token
+        # budget into a character budget (TODO confirm); int() because the splitter counts characters
+        chunk_size = int(max(self.n_ctx * 0.7, 2048) * 4)
         text_splitter = CharacterTextSplitter(
-            separator='\n\n|\\.|\\?|\\!', chunk_size=chunk_size, chunk_overlap=0, keep_separator=True)
-        texts = text_splitter.split_text(inputs['input'])
-        i = 0
-        while i == 0 or sum([len(text) for text in texts]) > summary_size:
-            docs = [texts[i:i + 3] for i in range(0, len(texts), 3)]
-            outputs = self.llm_chain.apply([{"user_prompt": self.user_prompt.format_prompt(text=''.join(doc)), "system_prompt": self.system_prompt} for doc in docs])
-            texts = [output[self.output_key] for output in outputs]
-            i += 1
-
-        return {self.output_key: '\n\n'.join(texts)}
-
-    @property
-    def _chain_type(self) -> str:
-        return "summarize_chain"
+            separator='\n\n|\\.|\\?|\\!', is_separator_regex=True, chunk_size=chunk_size, chunk_overlap=0, keep_separator=True)
+        chunks = text_splitter.split_text(inputs['input'])
+        while True:
+            num_chunks = len(chunks)
+            summaries = [
+                self.runnable.invoke({"user_prompt": self.user_prompt.format_prompt(text=chunk), "system_prompt": self.system_prompt})
+                for chunk in chunks
+            ]
+            chunks = text_splitter.split_text('\n\n'.join(summaries))
+            # stop once re-splitting the joined summaries no longer yields fewer chunks
+            if len(chunks) >= num_chunks:
+                break
+
+        return {'output': '\n\n'.join(summaries)}