From 4534b1dbd7c760497645f795a993a132f6b759f2 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 18 Jan 2024 06:24:53 +0000 Subject: [PATCH] remove prompt agent --- .../examples/automatic_prompt_engineer.py | 125 ------------------ erniebot-agent/examples/text_summarization.py | 61 --------- erniebot-agent/examples/utils.py | 120 ----------------- .../src/erniebot_agent/agents/prompt_agent.py | 54 -------- .../src/erniebot/resources/chat_completion.py | 1 + 5 files changed, 1 insertion(+), 360 deletions(-) delete mode 100644 erniebot-agent/examples/automatic_prompt_engineer.py delete mode 100644 erniebot-agent/examples/text_summarization.py delete mode 100644 erniebot-agent/examples/utils.py delete mode 100644 erniebot-agent/src/erniebot_agent/agents/prompt_agent.py diff --git a/erniebot-agent/examples/automatic_prompt_engineer.py b/erniebot-agent/examples/automatic_prompt_engineer.py deleted file mode 100644 index 685c197c0..000000000 --- a/erniebot-agent/examples/automatic_prompt_engineer.py +++ /dev/null @@ -1,125 +0,0 @@ -import argparse -import asyncio - -from erniebot_agent.agents.prompt_agent import PromptAgent -from erniebot_agent.chat_models import ERNIEBot -from erniebot_agent.memory import WholeMemory -from erniebot_agent.tools.openai_search_tool import OpenAISearchTool -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.vectorstores import FAISS -from prettytable import PrettyTable -from utils import ( - create_description, - create_keywords, - create_questions, - erniebot_chat, - read_data, -) - -import erniebot - -# yapf: disable -parser = argparse.ArgumentParser() -parser.add_argument("--api_type", default=None, type=str, help="The API Key.") -parser.add_argument("--access_token", default=None, type=str, help="The secret key.") -parser.add_argument("--summarization_path", default='data/data.jsonl', type=str, help="The output path.") -parser.add_argument("--number_of_prompts", default=3, type=int, help="The number of tool descriptions.") -parser.add_argument("--num_questions", default=3, type=int, help="The number of few shot questions.") -parser.add_argument("--num_keywords", default=-1, type=int, help="The number of few shot questions.") -args = parser.parse_args() -# yapf: enable - - -def generate_candidate_prompts(description, number_of_prompts): - prompts = [] - for i in range(number_of_prompts): - messages = [create_description(description)] - results = erniebot_chat( - messages, model="ernie-bot-4", api_type=args.api_type, access_token=args.access_token - )["result"] - prompts.append(results) - return prompts - - -def generate_candidate_questions( - description, num_questions: int = -1, num_keywords: int = -1, temperature=1e-10 -): - if num_questions > 0: - messages = [create_questions(description, num_questions=num_questions)] - elif num_keywords > 0: - messages = [create_keywords(description, num_keywords=num_keywords)] - - results = erniebot_chat( - messages, api_type=args.api_type, access_token=args.access_token, temperature=temperature - )["result"] - return results - - -if __name__ == "__main__": - erniebot.api_type = args.api_type - erniebot.access_token = args.access_token - - embeddings = OpenAIEmbeddings(deployment="text-embedding-ada") - faiss = FAISS.load_local("城市管理执法办法", embeddings) - - list_data = read_data(args.summarization_path) - doc = list_data[0] - tool_descriptions = generate_candidate_prompts(doc["abstract"], number_of_prompts=args.number_of_prompts) - tool_descriptions = list(set(tool_descriptions)) - print(tool_descriptions) - - questions = generate_candidate_questions(doc["abstract"], num_questions=args.num_questions).split("\n") - - if args.num_keywords > 0: - keywords = generate_candidate_questions(doc["abstract"], num_keywords=args.num_keywords).split("\n") - questions += keywords - - prompts = tool_descriptions - prompt_results = {prompt: {"correct": 0.0, "total": 0.0} for prompt in prompts} - - # Initialize the table - table = PrettyTable() - table_field_names = ["Prompt"] + [ - f"question {i+1}-{j+1}" for j, prompt in enumerate(questions) for i in range(questions.count(prompt)) - ] - table.field_names = table_field_names - - # Wrap the text in the "Prompt" column - table.max_width["Prompt"] = 100 - - llm = ERNIEBot(model="ernie-bot") - best_prompt = None - best_percentage = 0.0 - for i, tool_description in enumerate(tool_descriptions): - openai_city_management = OpenAISearchTool( - name="city_administrative_law_enforcement", - description=tool_description, - db=faiss, - threshold=0.1, - ) - row = [tool_description] - resps = [] - for query in questions: - agent = PromptAgent(memory=WholeMemory(), llm=llm, tools=[openai_city_management]) - response = asyncio.run(agent.async_run(query)) - resps.append(response) - if response is True: - prompt_results[tool_description]["correct"] += 1 - row.append("✅") - else: - row.append("❌") - prompt_results[tool_description]["total"] += 1 - table.add_row(row) - - print(f"生成的问题如下:{questions}") - print(table) - for i, prompt in enumerate(prompts): - correct = prompt_results[prompt]["correct"] - total = prompt_results[prompt]["total"] - percentage = (correct / total) * 100 - print(f"Prompt {i+1} got {percentage:.2f}% correct.") - if percentage > best_percentage: - best_percentage = percentage - best_prompt = tool_description - - print(f"The best prompt was '{best_prompt}' with a correctness of {best_percentage:.2f}%.") diff --git a/erniebot-agent/examples/text_summarization.py b/erniebot-agent/examples/text_summarization.py deleted file mode 100644 index 9f4024707..000000000 --- a/erniebot-agent/examples/text_summarization.py +++ /dev/null @@ -1,61 +0,0 @@ -import argparse -import os - -import jsonlines -from utils import create_abstract, erniebot_chat, read_data, split_text - -# yapf: disable -parser = argparse.ArgumentParser() -parser.add_argument("--api_type", default=None, type=str, help="The API Key.") -parser.add_argument("--access_token", default=None, type=str, help="The secret key.") -parser.add_argument("--data_path", default='data/json_data.jsonl', type=str, help="The data path.") -parser.add_argument("--output_path", default='data/finance_abstract', type=str, help="The output path.") -parser.add_argument('--chatbot_type', choices=['erniebot'], default="erniebot", - help="The chatbot model types") -args = parser.parse_args() -# yapf: enable - - -def summarize_text(text: str): - summaries = [] - - chunks = list(split_text(text, max_length=4096)) - print(f"Summarizing text with total chunks: {len(chunks)}") - for i, chunk in enumerate(chunks): - messages = [create_abstract(chunk)] - summary = erniebot_chat(messages, api_type=args.api_type, access_token=args.access_token).rbody[ - "result" - ] - print(summary) - summaries.append(summary) - - combined_summary = "\n".join(summaries) - combined_summary = combined_summary[:7000] - messages = [create_abstract(combined_summary)] - - final_summary = erniebot_chat(messages, api_type=args.api_type, access_token=args.access_token).rbody[ - "result" - ] - print("Final summary length: ", len(final_summary)) - print(final_summary) - return final_summary - - -def generate_summary_jsonl(): - os.makedirs(args.output_path, exist_ok=True) - list_data = read_data(args.data_path) - for md_file in list_data: - markdown_text = md_file["content"] - summary = summarize_text(markdown_text) - md_file["abstract"] = summary - - output_json = f"{args.output_path}/data.jsonl" - with jsonlines.open(output_json, "w") as f: - for item in list_data: - f.write(item) - return output_json - - -if __name__ == "__main__": - # text summarization - generate_summary_jsonl() diff --git a/erniebot-agent/examples/utils.py b/erniebot-agent/examples/utils.py deleted file mode 100644 index 25a373c56..000000000 --- a/erniebot-agent/examples/utils.py +++ /dev/null @@ -1,120 +0,0 @@ -from typing import Dict, Generator - -import jsonlines - -import erniebot - - -def read_data(json_path): - list_data = [] - with jsonlines.open(json_path, "r") as f: - for item in f: - list_data.append(item) - return list_data - - -def create_abstract(chunk: str) -> Dict[str, str]: - """Create a message for the chat completion - - Args: - chunk (str): The chunk of text to summarize - question (str): The question to answer - - Returns: - Dict[str, str]: The message to send to the chat completion - """ - return { - "role": "user", - "content": f"""{chunk},请用中文对上述文章进行总结,总结需要有概括性,不允许输出与文章内容无关的信息,字数控制在500字以内。""", - } - - -def create_questions(chunk: str, num_questions: int = 5): - return { - "role": "user", - "content": f"""{chunk},请根据上面的摘要,生成{num_questions}个问题,问题内容和形式要多样化,口语化,不允许重复,分条列举出来.""", - } - - -def create_keywords(chunk: str, num_keywords: int = 3): - return { - "role": "user", - "content": f"""{chunk},请根据上面的摘要,抽取{num_keywords}个关键字或者简短关键句,分条列举出来。 - 要求:只需要输出关键字或者简短的关键句,不需要输出其它的内容.""", - } - - -def create_description(chunk: str) -> Dict[str, str]: - """Create a message for the chat completion - - Args: - chunk (str): The chunk of text to summarize - question (str): The question to answer - - Returns: - Dict[str, str]: The message to send to the chat completion - """ - return { - "role": "user", - "content": f"""{chunk},请根据上面的摘要,生成一个简短的描述,不超过40字.""", - } - - -def split_text(text: str, max_length: int = 8192) -> Generator[str, None, None]: - """Split text into chunks of a maximum length - - Args: - text (str): The text to split - max_length (int, optional): The maximum length of each chunk. Defaults to 8192. - - Yields: - str: The next chunk of text - - Raises: - ValueError: If the text is longer than the maximum length - """ - paragraphs = text.split("\n") - current_length = 0 - current_chunk = [] - - for paragraph in paragraphs: - if current_length + len(paragraph) + 1 <= max_length: - current_chunk.append(paragraph) - current_length += len(paragraph) + 1 - else: - yield "\n".join(current_chunk) - current_chunk = [paragraph] - current_length = len(paragraph) + 1 - - if current_chunk: - yield "\n".join(current_chunk) - - -def erniebot_chat( - messages, model="ernie-bot", api_type="aistudio", access_token=None, functions=None, **kwargs -): - """ - Args: - messages: dict or list, 输入的消息(message) - model: str, 模型名称 - api_type: str, 接口类型,可选值包括 'aistudio' 和 'qianfan' - access_token: str, 访问令牌(access token) - functions: list, 函数列表 - kwargs: 其他参数 - - Returns: - dict or list, 返回聊天结果 - """ - _config = dict( - api_type=api_type, - access_token=access_token, - ) - if functions is None: - resp_stream = erniebot.ChatCompletion.create( - _config_=_config, model=model, messages=messages, **kwargs - ) - else: - resp_stream = erniebot.ChatCompletion.create( - _config_=_config, model=model, messages=messages, **kwargs, functions=functions - ) - return resp_stream diff --git a/erniebot-agent/src/erniebot_agent/agents/prompt_agent.py b/erniebot-agent/src/erniebot_agent/agents/prompt_agent.py deleted file mode 100644 index 648f2b42d..000000000 --- a/erniebot-agent/src/erniebot_agent/agents/prompt_agent.py +++ /dev/null @@ -1,54 +0,0 @@ -import json -from typing import List, Optional, Sequence - -from erniebot_agent.agents.agent import Agent -from erniebot_agent.agents.schema import AgentResponse, AgentStep, File -from erniebot_agent.memory.messages import HumanMessage, Message -from erniebot_agent.tools.base import BaseTool - - -class PromptAgent(Agent): - async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse: - chat_history: List[Message] = [] - steps_taken: List[AgentStep] = [] - run_input = await HumanMessage.create_with_files( - prompt, files or [], include_file_urls=self.file_needs_url - ) - chat_history.append(run_input) - msg = await self._step(chat_history) - text = json.dumps({"msg": msg}, ensure_ascii=False) - response = self._create_stopped_response(chat_history, steps_taken, message=text) - return response - - async def _step(self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None) -> bool: - new_messages: List[Message] = [] - input_messages = self.memory.get_messages() + chat_history - if selected_tool is not None: - tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}} - llm_resp = await self.run_llm( - messages=input_messages, - functions=[selected_tool.function_call_schema()], # only regist one tool - tool_choice=tool_choice, - ) - else: - llm_resp = await self.run_llm(messages=input_messages) - - output_message = llm_resp.message # AIMessage - new_messages.append(output_message) - if output_message.function_call is not None: - return True - else: - return False - - def _create_stopped_response( - self, - chat_history: List[Message], - steps: List[AgentStep], - message: str, - ) -> AgentResponse: - return AgentResponse( - text=message, - chat_history=chat_history, - steps=steps, - status="STOPPED", - ) diff --git a/erniebot/src/erniebot/resources/chat_completion.py b/erniebot/src/erniebot/resources/chat_completion.py index a4ef00c7f..e626b690c 100644 --- a/erniebot/src/erniebot/resources/chat_completion.py +++ b/erniebot/src/erniebot/resources/chat_completion.py @@ -515,6 +515,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None: # stream stream = kwargs.get("stream", False) + return RequestWithStream( method="POST", path=path,