From a4253ae0b618b186290b693e94a86dec881018a2 Mon Sep 17 00:00:00 2001 From: mrzaizai2k Date: Mon, 22 Jul 2024 15:11:04 +0700 Subject: [PATCH] change from gpt3.5 to gpt 4o mini --- src/summarize_text.py | 52 +++++++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/src/summarize_text.py b/src/summarize_text.py index c41706d..d90d080 100644 --- a/src/summarize_text.py +++ b/src/summarize_text.py @@ -31,8 +31,12 @@ # from langchain.chains import LLMChain # from langchain.prompts import PromptTemplate # from ctransformers import AutoModelForCausalLM, AutoTokenizer -from langchain_openai import OpenAI +from langchain_openai import ChatOpenAI + from datetime import datetime +from langchain_core.prompts import PromptTemplate +from langchain_core.output_parsers import CommaSeparatedListOutputParser + # from langchain_community.embeddings import HuggingFaceEmbeddings from src.Utils.utils import convert_m4a_to_mp3, SpellCheck @@ -44,24 +48,33 @@ def __init__(self, template_path:str = 'config/seperate_task_template.txt', ): self.template_path = template_path with open(self.template_path, 'r') as file: self.template = file.read() - self.load_llm() - # self.prompt = PromptTemplate(template=self.template, input_variables=['date', 'question']) - # self.llm_chain = LLMChain(prompt=self.prompt, llm=self.llm) + + self.OPENAI_API_KEY = os.getenv('OPENAI_API_KEY') + self.llm = self.load_llm() + self.output_parser = CommaSeparatedListOutputParser() def load_llm(self): - OPENAI_API_KEY = os.getenv('OPENAI_API_KEY') - self.llm = OpenAI(model="gpt-3.5-turbo-instruct", - openai_api_key=OPENAI_API_KEY, - max_tokens = 512, - temperature=0.7, - ) - + llm = ChatOpenAI(model="gpt-4o-mini", + openai_api_key=self.OPENAI_API_KEY, + max_tokens = 512, + temperature=0.7, + ) + return llm + def get_response(self, text) -> list: - - tmp_prompt = f"{self.template}\n Now at {datetime.now()}, please break down this text: {text}" - response = 
self.llm(tmp_prompt) - # print('raw', response) - response = self.get_tasks_from_string(response) + prompt = PromptTemplate.from_template( + "{template}\nNow at {current_time}, please break down this text: {input_text}\n{format_instructions}" + ) + + formatted_prompt = prompt.format( + template=self.template, + current_time=datetime.now(), + input_text=text, + format_instructions=self.output_parser.get_format_instructions() + ) + + response = self.llm.invoke(formatted_prompt) + response = self.get_tasks_from_string(response.content) return response def get_tasks_from_string(self, text:str) -> list: @@ -69,6 +82,8 @@ def get_tasks_from_string(self, text:str) -> list: start_index = text.find('[') end_index = text.find(']') + 1 json_string = text[start_index:end_index] + json_string = json_string.replace('true', 'True').replace('false', 'False').replace('null', 'None') + tasks_list = eval(json_string) return tasks_list @@ -439,7 +454,7 @@ def segment_text(self, result) -> str: segmented_text = self._post_processing_transcribed_text(segmented_text) try: - # Send the joined text to self.task_seperator.get_response() to get a list + # Send the joined text to self.task_seperator.get_response() to get a list self.response_list = self.task_seperator.get_response(text=segmented_text) # Separate the received list by newline titles = [task['title'] for task in self.response_list] @@ -477,11 +492,10 @@ def get_task_list(self): if __name__ == "__main__": - audio_path='data/audio_2.ogg' speech_to_text = SpeechSummaryProcessor(audio_path=audio_path) text = speech_to_text.generate_speech_to_text() - print ('Text', text) + print ('Text:', text) # symbol = 'SSI' # date_format='year'