diff --git a/comps/agent/langchain/agent.py b/comps/agent/langchain/agent.py index b0fb1b81b..9530d7bb6 100644 --- a/comps/agent/langchain/agent.py +++ b/comps/agent/langchain/agent.py @@ -5,6 +5,8 @@ import os import pathlib import sys +from datetime import datetime +from typing import Union from fastapi.responses import StreamingResponse @@ -14,7 +16,21 @@ from comps import CustomLogger, GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice from comps.agent.langchain.src.agent import instantiate_agent +from comps.agent.langchain.src.global_var import assistants_global_kv, threads_global_kv +from comps.agent.langchain.src.thread import instantiate_thread_memory, thread_completion_callback from comps.agent.langchain.src.utils import get_args +from comps.cores.proto.api_protocol import ( + AssistantsObject, + ChatCompletionRequest, + CreateAssistantsRequest, + CreateMessagesRequest, + CreateRunResponse, + CreateThreadsRequest, + MessageContent, + MessageObject, + RunObject, + ThreadObject, +) logger = CustomLogger("comps-react-agent") logflag = os.getenv("LOGFLAG", False) @@ -23,14 +39,13 @@ @register_microservice( - name="opea_service@comps-react-agent", + name="opea_service@comps-chat-agent", service_type=ServiceType.LLM, endpoint="/v1/chat/completions", host="0.0.0.0", port=args.port, - input_datatype=LLMParamsDoc, ) -async def llm_generate(input: LLMParamsDoc): +async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]): if logflag: logger.info(input) # 1. initialize the agent @@ -42,19 +57,149 @@ async def llm_generate(input: LLMParamsDoc): if logflag: logger.info(type(agent_inst)) + if isinstance(input, LLMParamsDoc): + # use query as input + input_query = input.query + else: + # openai compatible input + if isinstance(input.messages, str): + input_query = input.messages + else: + input_query = input.messages[-1]["content"] + # 2. prepare the input for the agent if input.streaming: print("-----------STREAMING-------------") - return StreamingResponse(agent_inst.stream_generator(input.query, config), media_type="text/event-stream") + return StreamingResponse(agent_inst.stream_generator(input_query, config), media_type="text/event-stream") else: - # TODO: add support for non-streaming mode print("-----------NOT STREAMING-------------") - response = await agent_inst.non_streaming_run(input.query, config) + response = await agent_inst.non_streaming_run(input_query, config) print("-----------Response-------------") print(response) - return GeneratedDoc(text=response, prompt=input.query) + return GeneratedDoc(text=response, prompt=input_query) + + +@register_microservice( + name="opea_service@comps-chat-agent", + endpoint="/v1/assistants", + host="0.0.0.0", + port=args.port, +) +def create_assistants(input: CreateAssistantsRequest): + # 1. 
initialize the agent + print("args: ", args) + agent_inst = instantiate_agent(args, args.strategy, with_memory=True) + agent_id = agent_inst.id + created_at = int(datetime.now().timestamp()) + with assistants_global_kv as g_assistants: + g_assistants[agent_id] = (agent_inst, created_at) + print(f"Record assistant inst {agent_id} in global KV") + + # get current time in string format + return AssistantsObject( + id=agent_id, + created_at=created_at, + ) + + +@register_microservice( + name="opea_service@comps-chat-agent", + endpoint="/v1/threads", + host="0.0.0.0", + port=args.port, +) +def create_threads(input: CreateThreadsRequest): + # create a memory KV for the thread + thread_inst, thread_id = instantiate_thread_memory() + created_at = int(datetime.now().timestamp()) + status = "ready" + with threads_global_kv as g_threads: + g_threads[thread_id] = (thread_inst, created_at, status) + print(f"Record thread inst {thread_id} in global KV") + + return ThreadObject( + id=thread_id, + created_at=created_at, + ) + + +@register_microservice( + name="opea_service@comps-chat-agent", + endpoint="/v1/threads/{thread_id}/messages", + host="0.0.0.0", + port=args.port, +) +def create_messages(thread_id, input: CreateMessagesRequest): + with threads_global_kv as g_threads: + thread_inst, _, _ = g_threads[thread_id] + + # create a memory KV for the message + role = input.role + if isinstance(input.content, str): + query = input.content + else: + query = input.content[-1]["text"] + msg_id, created_at = thread_inst.add_query(query) + + structured_content = MessageContent(text=query) + return MessageObject( + id=msg_id, + created_at=created_at, + thread_id=thread_id, + role=role, + content=[structured_content], + ) + + +@register_microservice( + name="opea_service@comps-chat-agent", + endpoint="/v1/threads/{thread_id}/runs", + host="0.0.0.0", + port=args.port, +) +def create_run(thread_id, input: CreateRunResponse): + with threads_global_kv as g_threads: + thread_inst, _, status = g_threads[thread_id] + + if status == "running": + return "[error] Thread is already running, need to cancel the current run or wait for it to finish" + + agent_id = input.assistant_id + with assistants_global_kv as g_assistants: + agent_inst, _ = g_assistants[agent_id] + + config = {"recursion_limit": args.recursion_limit} + input_query = thread_inst.get_query() + try: + return StreamingResponse( + thread_completion_callback(agent_inst.stream_generator(input_query, config, thread_id), thread_id), + media_type="text/event-stream", + ) + except Exception as e: + with threads_global_kv as g_threads: + thread_inst, created_at, status = g_threads[thread_id] + g_threads[thread_id] = (thread_inst, created_at, "ready") + return f"An error occurred: {e}. 
This thread is now set as ready" + + +@register_microservice( + name="opea_service@comps-chat-agent", + endpoint="/v1/threads/{thread_id}/runs/cancel", + host="0.0.0.0", + port=args.port, +) +def cancel_run(thread_id): + with threads_global_kv as g_threads: + thread_inst, created_at, status = g_threads[thread_id] + if status == "ready": + return "Thread is not running, no need to cancel" + elif status == "try_cancel": + return "cancel request is submitted" + else: + g_threads[thread_id] = (thread_inst, created_at, "try_cancel") + return "submit cancel request" if __name__ == "__main__": - opea_microservices["opea_service@comps-react-agent"].start() + opea_microservices["opea_service@comps-chat-agent"].start() diff --git a/comps/agent/langchain/requirements.txt b/comps/agent/langchain/requirements.txt index 1da88cb85..6e4dd1012 100644 --- a/comps/agent/langchain/requirements.txt +++ b/comps/agent/langchain/requirements.txt @@ -5,12 +5,13 @@ docarray[full] duckduckgo-search fastapi huggingface_hub==0.24.0 -langchain #==0.1.12 +langchain==0.2.9 langchain-huggingface langchain-openai -langchain_community -langchainhub +langchain_community==0.2.7 +langchainhub==0.1.20 langgraph +langsmith numpy # used by cloud native diff --git a/comps/agent/langchain/src/agent.py b/comps/agent/langchain/src/agent.py index 55a50e81d..9accf8a35 100644 --- a/comps/agent/langchain/src/agent.py +++ b/comps/agent/langchain/src/agent.py @@ -2,23 +2,23 @@ # SPDX-License-Identifier: Apache-2.0 -def instantiate_agent(args, strategy="react_langchain"): +def instantiate_agent(args, strategy="react_langchain", with_memory=False): if strategy == "react_langchain": from .strategy.react import ReActAgentwithLangchain - return ReActAgentwithLangchain(args) + return ReActAgentwithLangchain(args, with_memory) elif strategy == "react_langgraph": from .strategy.react import ReActAgentwithLanggraph - return ReActAgentwithLanggraph(args) + return ReActAgentwithLanggraph(args, with_memory) elif strategy == "plan_execute": from .strategy.planexec import PlanExecuteAgentWithLangGraph - return PlanExecuteAgentWithLangGraph(args) + return PlanExecuteAgentWithLangGraph(args, with_memory) elif strategy == "rag_agent": from .strategy.ragagent import RAGAgent - return RAGAgent(args) + return RAGAgent(args, with_memory) else: raise ValueError(f"Agent strategy: {strategy} not supported!") diff --git a/comps/agent/langchain/src/global_var.py b/comps/agent/langchain/src/global_var.py new file mode 100644 index 000000000..02a2083b8 --- /dev/null +++ b/comps/agent/langchain/src/global_var.py @@ -0,0 +1,21 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import threading + + +class ThreadSafeDict(dict): + def __init__(self, *p_arg, **n_arg): + dict.__init__(self, *p_arg, **n_arg) + self._lock = threading.Lock() + + def __enter__(self): + self._lock.acquire() + return self + + def __exit__(self, type, value, traceback): + self._lock.release() + + +assistants_global_kv = ThreadSafeDict() +threads_global_kv = ThreadSafeDict() diff --git a/comps/agent/langchain/src/strategy/base_agent.py b/comps/agent/langchain/src/strategy/base_agent.py index f9e8fed9e..ca0e12a96 100644 --- a/comps/agent/langchain/src/strategy/base_agent.py +++ b/comps/agent/langchain/src/strategy/base_agent.py @@ -1,6 +1,8 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +from uuid import uuid4 + from ..tools import get_tools_descriptions from ..utils import setup_llm @@ -10,6 +12,8 @@ def __init__(self, args) -> 
None: self.llm_endpoint = setup_llm(args) self.tools_descriptions = get_tools_descriptions(args.tools) self.app = None + self.memory = None + self.id = f"assistant_{self.__class__.__name__}_{uuid4()}" print(self.tools_descriptions) def compile(self): diff --git a/comps/agent/langchain/src/strategy/planexec/planner.py b/comps/agent/langchain/src/strategy/planexec/planner.py index 601e28bf8..c33c906f4 100644 --- a/comps/agent/langchain/src/strategy/planexec/planner.py +++ b/comps/agent/langchain/src/strategy/planexec/planner.py @@ -16,9 +16,11 @@ from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.utils.json import parse_partial_json from langchain_huggingface import ChatHuggingFace +from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages +from ...global_var import threads_global_kv from ...utils import has_multi_tool_inputs, tool_renderer from ..base_agent import BaseAgent from .prompt import ( @@ -221,7 +223,7 @@ def __call__(self, state): class PlanExecuteAgentWithLangGraph(BaseAgent): - def __init__(self, args): + def __init__(self, args, with_memory=False): super().__init__(args) # Define Node @@ -231,37 +233,39 @@ def __init__(self, args): execute_step = Executor(self.llm_endpoint, args.model, self.tools_descriptions) make_answer = AnswerMaker(self.llm_endpoint, args.model) - # answer_checker = FinalAnswerChecker(self.llm_endpoint, args.model) - # replan_step = Replanner(self.llm_endpoint, args.model, answer_checker) - # Define Graph workflow = StateGraph(PlanExecute) workflow.add_node("planner", plan_step) workflow.add_node("plan_executor", execute_step) workflow.add_node("answer_maker", make_answer) - # workflow.add_node("replan", replan_step) # Define edges workflow.add_edge(START, "planner") workflow.add_edge("planner", "plan_executor") workflow.add_edge("plan_executor", "answer_maker") workflow.add_edge("answer_maker", END) - # workflow.add_conditional_edges( - # "answer_maker", - # answer_checker, - # {END: END, "replan": "replan"}, - # ) - # workflow.add_edge("replan", "plan_executor") - # Finally, we compile it! - self.app = workflow.compile() + if with_memory: + self.app = workflow.compile(checkpointer=MemorySaver()) + else: + self.app = workflow.compile() def prepare_initial_state(self, query): return {"messages": [("user", query)]} - async def stream_generator(self, query, config): + async def stream_generator(self, query, config, thread_id=None): initial_state = self.prepare_initial_state(query) + if thread_id is not None: + config["configurable"] = {"thread_id": thread_id} async for event in self.app.astream(initial_state, config=config): + if thread_id is not None: + with threads_global_kv as g_threads: + thread_inst, created_at, status = g_threads[thread_id] + if status == "try_cancel": + yield "[thread_completion_callback] signal to cancel! Changed status to ready" + print("[thread_completion_callback] signal to cancel! 
Changed status to ready") + g_threads[thread_id] = (thread_inst, created_at, "ready") + break for node_name, node_state in event.items(): yield f"--- CALL {node_name} ---\n" for k, v in node_state.items(): diff --git a/comps/agent/langchain/src/strategy/ragagent/planner.py b/comps/agent/langchain/src/strategy/ragagent/planner.py index 1a60f3ca2..e618aed80 100644 --- a/comps/agent/langchain/src/strategy/ragagent/planner.py +++ b/comps/agent/langchain/src/strategy/ragagent/planner.py @@ -11,6 +11,7 @@ from langchain_core.pydantic_v1 import BaseModel, Field from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages from langgraph.prebuilt import ToolNode, tools_condition @@ -154,7 +155,7 @@ def __call__(self, state): class RAGAgent(BaseAgent): - def __init__(self, args): + def __init__(self, args, with_memory=False): super().__init__(args) # Define Nodes @@ -195,7 +196,10 @@ def __init__(self, args): ) workflow.add_edge("generate", END) - self.app = workflow.compile() + if with_memory: + self.app = workflow.compile(checkpointer=MemorySaver()) + else: + self.app = workflow.compile() def should_retry(self, state): # first check how many retry attempts have been made diff --git a/comps/agent/langchain/src/strategy/react/planner.py b/comps/agent/langchain/src/strategy/react/planner.py index 93d185493..4466a115f 100644 --- a/comps/agent/langchain/src/strategy/react/planner.py +++ b/comps/agent/langchain/src/strategy/react/planner.py @@ -3,18 +3,22 @@ from langchain.agents import AgentExecutor from langchain.agents import create_react_agent as create_react_langchain_agent +from langchain.memory import ChatMessageHistory from langchain_core.messages import HumanMessage +from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver from langgraph.prebuilt import create_react_agent +from ...global_var import threads_global_kv from ...utils import has_multi_tool_inputs, tool_renderer from ..base_agent import BaseAgent from .prompt import REACT_SYS_MESSAGE, hwchase17_react_prompt class ReActAgentwithLangchain(BaseAgent): - def __init__(self, args): + def __init__(self, args, with_memory=False): super().__init__(args) prompt = hwchase17_react_prompt if has_multi_tool_inputs(self.tools_descriptions): @@ -26,13 +30,41 @@ def __init__(self, args): self.app = AgentExecutor( agent=agent_chain, tools=self.tools_descriptions, verbose=True, handle_parsing_errors=True ) + self.memory = {} + + def get_session_history(session_id): + if session_id in self.memory: + return self.memory[session_id] + else: + mem = ChatMessageHistory() + self.memory[session_id] = mem + return mem + + if with_memory: + self.app = RunnableWithMessageHistory( + self.app, + get_session_history, + input_messages_key="input", + history_messages_key="chat_history", + history_factory_config=[], + ) def prepare_initial_state(self, query): return {"input": query} - async def stream_generator(self, query, config): + async def stream_generator(self, query, config, thread_id=None): initial_state = self.prepare_initial_state(query) + if thread_id is not None: + config["configurable"] = {"session_id": thread_id} async for chunk in self.app.astream(initial_state, config=config): + 
if thread_id is not None: + with threads_global_kv as g_threads: + thread_inst, created_at, status = g_threads[thread_id] + if status == "try_cancel": + yield "[thread_completion_callback] signal to cancel! Changed status to ready" + print("[thread_completion_callback] signal to cancel! Changed status to ready") + g_threads[thread_id] = (thread_inst, created_at, "ready") + break if "actions" in chunk: for action in chunk["actions"]: yield f"Calling Tool: `{action.tool}` with input `{action.tool_input}`\n\n" @@ -50,7 +82,7 @@ async def stream_generator(self, query, config): class ReActAgentwithLanggraph(BaseAgent): - def __init__(self, args): + def __init__(self, args, with_memory=False): super().__init__(args) if isinstance(self.llm_endpoint, HuggingFaceEndpoint): @@ -60,7 +92,12 @@ def __init__(self, args): tools = self.tools_descriptions - self.app = create_react_agent(self.llm, tools=tools, state_modifier=REACT_SYS_MESSAGE) + if with_memory: + self.app = create_react_agent( + self.llm, tools=tools, state_modifier=REACT_SYS_MESSAGE, checkpointer=MemorySaver() + ) + else: + self.app = create_react_agent(self.llm, tools=tools, state_modifier=REACT_SYS_MESSAGE) def prepare_initial_state(self, query): return {"messages": [HumanMessage(content=query)]} diff --git a/comps/agent/langchain/src/thread.py b/comps/agent/langchain/src/thread.py new file mode 100644 index 000000000..441a2936e --- /dev/null +++ b/comps/agent/langchain/src/thread.py @@ -0,0 +1,43 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections import deque +from datetime import datetime +from uuid import uuid4 + +from .global_var import threads_global_kv + + +class ThreadMemory: + def __init__(self): + self.query_list = deque() + + def add_query(self, query): + msg_id = f"msg_{uuid4()}" + created_at = int(datetime.now().timestamp()) + + self.query_list.append((query, msg_id, created_at)) + + return msg_id, created_at + + def get_query(self): + query, _, _ = self.query_list.pop() + return query + + +async def thread_completion_callback(content, thread_id): + with threads_global_kv as g_threads: + thread_inst, created_at, _ = g_threads[thread_id] + g_threads[thread_id] = (thread_inst, created_at, "running") + print("[thread_completion_callback] Changed status to running") + async for chunk in content: + if "data: [DONE]\n\n" == chunk: + with threads_global_kv as g_threads: + thread_inst, created_at, _ = g_threads[thread_id] + g_threads[thread_id] = (thread_inst, created_at, "ready") + yield chunk + + +def instantiate_thread_memory(args=None): + thread_id = f"thread_{uuid4()}" + return ThreadMemory(), thread_id diff --git a/comps/agent/langchain/test.py b/comps/agent/langchain/test.py index d3f5d4506..cb7cc0424 100644 --- a/comps/agent/langchain/test.py +++ b/comps/agent/langchain/test.py @@ -5,6 +5,7 @@ import json import os import traceback +from time import sleep import pandas as pd import requests @@ -85,6 +86,68 @@ def process_request(query): df.to_csv(os.path.join(args.filedir, args.output), index=False) +def test_assistants_http(args): + proxies = {"http": ""} + ip_addr = args.ip_addr + url = f"http://{ip_addr}:9090/v1" + + def process_request(api, query, is_stream=False): + content = json.dumps(query) if query is not None else None + print(f"send request to {url}/{api}, data is {content}") + try: + resp = requests.post(url=f"{url}/{api}", data=content, proxies=proxies, stream=is_stream) + if not is_stream: + ret = resp.json() + print(ret) + else: + for line in 
resp.iter_lines(decode_unicode=True): + print(line) + ret = None + + resp.raise_for_status() # Raise an exception for unsuccessful HTTP status codes + return ret + except requests.exceptions.RequestException as e: + ret = f"An error occurred:{e}" + print(ret) + return False + + # step 1. create assistants + query = {} + if ret := process_request("assistants", query): + assistant_id = ret.get("id") + print("Created Assistant Id: ", assistant_id) + else: + print("Error when creating assistants !!!!") + return + + # step 2. create threads + query = {} + if ret := process_request("threads", query): + thread_id = ret.get("id") + print("Created Thread Id: ", thread_id) + else: + print("Error when creating threads !!!!") + return + + # step 3. add messages + if args.query is None: + query = {"role": "user", "content": "How old was Bill Gates when he built Microsoft?"} + else: + query = {"role": "user", "content": args.query} + if ret := process_request(f"threads/{thread_id}/messages", query): + pass + else: + print("Error when add messages !!!!") + return + + # step 4. run + print("You may cancel the running process with cmdline") + print(f"curl {url}/threads/{thread_id}/runs/cancel -X POST -H 'Content-Type: application/json'") + + query = {"assistant_id": assistant_id} + process_request(f"threads/{thread_id}/runs", query, is_stream=True) + + def test_ut(args): from src.tools import get_tools_descriptions @@ -99,8 +162,10 @@ def test_ut(args): parser.add_argument("--strategy", type=str, default="react") parser.add_argument("--local_test", action="store_true", help="Test with local mode") parser.add_argument("--endpoint_test", action="store_true", help="Test with endpoint mode") + parser.add_argument("--assistants_api_test", action="store_true", help="Test with endpoint mode") parser.add_argument("--q", type=int, default=0) parser.add_argument("--ip_addr", type=str, default="127.0.0.1", help="endpoint ip address") + parser.add_argument("--query", type=str, default=None) parser.add_argument("--filedir", type=str, default="./", help="test file directory") parser.add_argument("--filename", type=str, default="query.csv", help="query_list_file") parser.add_argument("--output", type=str, default="output.csv", help="query_list_file") @@ -117,5 +182,7 @@ def test_ut(args): test_agent_http(args) elif args.ut: test_ut(args) + elif args.assistants_api_test: + test_assistants_http(args) else: print("Please specify the test type") diff --git a/comps/agent/langchain/test_assistant_api.py b/comps/agent/langchain/test_assistant_api.py new file mode 100644 index 000000000..cf398c8bb --- /dev/null +++ b/comps/agent/langchain/test_assistant_api.py @@ -0,0 +1,97 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import json + +import requests +from src.utils import get_args + + +def test_assistants_http(args): + proxies = {"http": ""} + url = f"http://{args.ip_addr}:{args.ext_port}/v1" + + def process_request(api, query, is_stream=False): + content = json.dumps(query) if query is not None else None + print(f"send request to {url}/{api}, data is {content}") + try: + resp = requests.post(url=f"{url}/{api}", data=content, proxies=proxies, stream=is_stream) + if not is_stream: + ret = resp.json() + print(ret) + else: + for line in resp.iter_lines(decode_unicode=True): + print(line) + ret = None + + resp.raise_for_status() # Raise an exception for unsuccessful HTTP status codes + return ret + except requests.exceptions.RequestException as e: + ret = f"An error 
occurred:{e}" + print(ret) + return False + + # step 1. create assistants + query = {} + if ret := process_request("assistants", query): + assistant_id = ret.get("id") + print("Created Assistant Id: ", assistant_id) + else: + print("Error when creating assistants !!!!") + return + + # step 2. create threads + query = {} + if ret := process_request("threads", query): + thread_id = ret.get("id") + print("Created Thread Id: ", thread_id) + else: + print("Error when creating threads !!!!") + return + + # step 3. add messages + if args.query is None: + query = {"role": "user", "content": "How old was Bill Gates when he built Microsoft?"} + else: + query = {"role": "user", "content": args.query} + if ret := process_request(f"threads/{thread_id}/messages", query): + pass + else: + print("Error when add messages !!!!") + return + + # step 4. run + print("You may cancel the running process with cmdline") + print(f"curl {url}/threads/{thread_id}/runs/cancel -X POST -H 'Content-Type: application/json'") + + query = {"assistant_id": assistant_id} + process_request(f"threads/{thread_id}/runs", query, is_stream=True) + + +if __name__ == "__main__": + args1, _ = get_args() + parser = argparse.ArgumentParser() + parser.add_argument("--strategy", type=str, default="react") + parser.add_argument("--local_test", action="store_true", help="Test with local mode") + parser.add_argument("--endpoint_test", action="store_true", help="Test with endpoint mode") + parser.add_argument("--assistants_api_test", action="store_true", help="Test with endpoint mode") + parser.add_argument("--q", type=int, default=0) + parser.add_argument("--ip_addr", type=str, default="127.0.0.1", help="endpoint ip address") + parser.add_argument("--ext_port", type=str, default="9090", help="endpoint port") + parser.add_argument("--query", type=str, default=None) + parser.add_argument("--filedir", type=str, default="./", help="test file directory") + parser.add_argument("--filename", type=str, default="query.csv", help="query_list_file") + parser.add_argument("--output", type=str, default="output.csv", help="query_list_file") + parser.add_argument("--ut", action="store_true", help="ut") + + args, _ = parser.parse_known_args() + + for key, value in vars(args1).items(): + setattr(args, key, value) + + if args.assistants_api_test: + print("test args:", args) + test_assistants_http(args) + else: + print("Please specify the test type") diff --git a/comps/cores/proto/api_protocol.py b/comps/cores/proto/api_protocol.py index 93602cebb..0b3094e1c 100644 --- a/comps/cores/proto/api_protocol.py +++ b/comps/cores/proto/api_protocol.py @@ -389,6 +389,82 @@ class ErrorResponse(BaseModel): code: int +class ThreadObject(BaseModel): + id: str + object: str = "thread" + created_at: int + + +class AssistantsObject(BaseModel): + id: str + object: str = "assistant" + created_at: int + name: Optional[str] = None + description: Optional[str] = None + model: Optional[str] = "Intel/neural-chat-7b-v3-3" + instructions: Optional[str] = None + tools: Optional[List[ChatCompletionToolsParam]] = None + + +class Attachments(BaseModel): + file_list: List[UploadFile] = [] + + +class MessageContent(BaseModel): + type: str = "text" + text: Optional[str] = None + + +class MessageObject(BaseModel): + id: str + object: str = "thread.message" + created_at: int + thread_id: str + role: str + status: Optional[str] = None + content: List[MessageContent] + assistant_id: Optional[str] = None + run_id: Optional[str] = None + attachments: Attachments = None + + +class 
RunObject(BaseModel): + id: str + object: str = "run" + created_at: int + thread_id: str + assistant_id: str + status: Optional[str] = None + last_error: Optional[str] = None + + +class CreateAssistantsRequest(BaseModel): + model: Optional[str] = None + name: Optional[str] = None + description: Optional[str] = None + instructions: Optional[str] = None + tools: Optional[List[ChatCompletionToolsParam]] = None + + +class CreateMessagesRequest(BaseModel): + role: str = "user" + content: Union[str, List[MessageContent]] + attachments: Attachments = None + + +class CreateThreadsRequest(BaseModel): + messages: Optional[List[CreateMessagesRequest]] = None + + +class CreateRunResponse(BaseModel): + assistant_id: str + + +class ListAssistantsRequest(BaseModel): + limit: int = 10 + order: Optional[str] = "desc" + + class ApiErrorCode(IntEnum): """ https://platform.openai.com/docs/guides/error-codes/api-errors diff --git a/tests/test_agent_langchain.sh b/tests/test_agent_langchain.sh index 8efaa971d..6bd97f8d2 100644 --- a/tests/test_agent_langchain.sh +++ b/tests/test_agent_langchain.sh @@ -34,12 +34,10 @@ function start_tgi_service() { echo "start tgi gaudi service" docker run -d --runtime=habana --name "test-comps-tgi-gaudi-service" -p $tgi_port:80 -v $tgi_volume:/data -e HF_TOKEN=$HF_TOKEN -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host ghcr.io/huggingface/tgi-gaudi:latest --model-id $model --max-input-tokens 4096 --max-total-tokens 8092 sleep 5s - docker logs test-comps-tgi-gaudi-service - echo "Waiting tgi gaudi ready" n=0 until [[ "$n" -ge 100 ]] || [[ $ready == true ]]; do - docker logs test-comps-tgi-gaudi-service + docker logs test-comps-tgi-gaudi-service &> ${LOG_PATH}/tgi-gaudi-service.log n=$((n+1)) if grep -q Connected ${WORKPATH}/tests/tgi-gaudi-service.log; then break @@ -47,14 +45,14 @@ function start_tgi_service() { sleep 5s done sleep 5s - docker logs test-comps-tgi-gaudi-service echo "Service started successfully" } function start_react_langchain_agent_service() { echo "Starting react_langchain agent microservice" - docker run -d --runtime=runc --name="comps-agent-endpoint" -v $WORKPATH/comps/agent/langchain/tools:/home/user/comps/agent/langchain/tools -p 9090:9090 --ipc=host -e HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e strategy=react_langchain -e llm_endpoint_url=http://${ip_address}:${tgi_port} -e llm_engine=tgi -e recursion_limit=10 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/comps-agent-langchain:comps + docker run -d --runtime=runc --name="comps-agent-endpoint" -v $WORKPATH/comps/agent/langchain/tools:/home/user/comps/agent/langchain/tools -p 5042:9090 --ipc=host -e HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e strategy=react_langchain -e llm_endpoint_url=http://${ip_address}:${tgi_port} -e llm_engine=tgi -e recursion_limit=10 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/comps-agent-langchain:comps sleep 5s + docker logs comps-agent-endpoint echo "Service started successfully" } @@ -62,7 +60,7 @@ function start_react_langchain_agent_service() { function start_react_langgraph_agent_service() { echo "Starting react_langgraph agent microservice" - docker run -d --runtime=runc --name="comps-agent-endpoint" -v $WORKPATH/comps/agent/langchain/tools:/home/user/comps/agent/langchain/tools -p 9090:9090 --ipc=host -e 
HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e strategy=react_langgraph -e llm_endpoint_url=http://${ip_address}:${tgi_port} -e llm_engine=tgi -e recursion_limit=10 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/comps-agent-langchain:comps + docker run -d --runtime=runc --name="comps-agent-endpoint" -v $WORKPATH/comps/agent/langchain/tools:/home/user/comps/agent/langchain/tools -p 5042:9090 --ipc=host -e HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e strategy=react_langgraph -e llm_endpoint_url=http://${ip_address}:${tgi_port} -e llm_engine=tgi -e recursion_limit=10 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/comps-agent-langchain:comps sleep 5s docker logs comps-agent-endpoint echo "Service started successfully" @@ -70,7 +68,7 @@ function start_react_langgraph_agent_service() { function start_react_langgraph_agent_service_openai() { echo "Starting react_langgraph agent microservice" - docker run -d --runtime=runc --name="comps-agent-endpoint" -v $WORKPATH/comps/agent/langchain/tools:/home/user/comps/agent/langchain/tools -p 9090:9090 --ipc=host -e model=gpt-4o-mini-2024-07-18 -e strategy=react_langgraph -e llm_engine=openai -e OPENAI_API_KEY=${OPENAI_API_KEY} -e recursion_limit=10 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/comps-agent-langchain:comps + docker run -d --runtime=runc --name="comps-agent-endpoint" -v $WORKPATH/comps/agent/langchain/tools:/home/user/comps/agent/langchain/tools -p 5042:9090 --ipc=host -e model=gpt-4o-mini-2024-07-18 -e strategy=react_langgraph -e llm_engine=openai -e OPENAI_API_KEY=${OPENAI_API_KEY} -e recursion_limit=10 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/comps-agent-langchain:comps sleep 5s docker logs comps-agent-endpoint echo "Service started successfully" @@ -79,7 +77,7 @@ function start_react_langgraph_agent_service_openai() { function start_ragagent_agent_service() { echo "Starting rag agent microservice" - docker run -d --runtime=runc --name="comps-agent-endpoint" -v $WORKPATH/comps/agent/langchain/tools:/home/user/comps/agent/langchain/tools -p 9090:9090 --ipc=host -e HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e strategy=rag_agent -e llm_endpoint_url=http://${ip_address}:${tgi_port} -e llm_engine=tgi -e recursion_limit=10 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/comps-agent-langchain:comps + docker run -d --runtime=runc --name="comps-agent-endpoint" -v $WORKPATH/comps/agent/langchain/tools:/home/user/comps/agent/langchain/tools -p 5042:9090 --ipc=host -e HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e strategy=rag_agent -e llm_endpoint_url=http://${ip_address}:${tgi_port} -e llm_engine=tgi -e recursion_limit=10 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/comps-agent-langchain:comps sleep 5s docker logs comps-agent-endpoint echo "Service started successfully" @@ -101,8 +99,8 @@ function validate() { } function validate_microservice() { - echo "Testing agent service" - local CONTENT=$(http_proxy="" curl http://${ip_address}:9090/v1/chat/completions -X POST -H "Content-Type: application/json" -d '{ + echo "Testing agent service - chat completion API" + local CONTENT=$(http_proxy="" curl 
http://${ip_address}:5042/v1/chat/completions -X POST -H "Content-Type: application/json" -d '{ "query": "What is Intel OPEA project?" }')
    local EXIT_CODE=$(validate "$CONTENT" "OPEA" "test-agent-langchain")
    echo "$EXIT_CODE"
    local EXIT_CODE="${EXIT_CODE:0-1}"
    echo "return value is $EXIT_CODE"
    if [ "$EXIT_CODE" == "1" ]; then
-        echo "==============tgi container log ==================="
-        docker logs test-comps-tgi-gaudi-service
-        echo "==============agent container log ==================="
-        docker logs comps-agent-endpoint
+        docker logs test-comps-tgi-gaudi-service &> ${LOG_PATH}/test-comps-tgi-gaudi-service.log
+        docker logs comps-agent-endpoint &> ${LOG_PATH}/test-comps-langchain-agent-endpoint.log
+        exit 1
+    fi
+}
+
+function validate_assistant_api() {
+    cd $WORKPATH
+    echo "Testing agent service - assistant api"
+    local CONTENT=$(python3 comps/agent/langchain/test_assistant_api.py --ip_addr ${ip_address} --ext_port 5042 --assistants_api_test --query 'What is Intel OPEA project?' 2>&1 | tee ${LOG_PATH}/test-agent-langchain-assistantsapi.log)
+    local EXIT_CODE=$(validate "$CONTENT" "OPEA" "test-agent-langchain-assistantsapi")
+    echo "$EXIT_CODE"
+    local EXIT_CODE="${EXIT_CODE:0-1}"
+    echo "return value is $EXIT_CODE"
+    if [ "$EXIT_CODE" == "1" ]; then
+        docker logs test-comps-tgi-gaudi-service &> ${LOG_PATH}/test-comps-tgi-gaudi-service.log
+        docker logs comps-agent-endpoint &> ${LOG_PATH}/test-comps-langchain-agent-endpoint.log
         exit 1
     fi
 }
@@ -140,15 +151,22 @@ function stop_docker() {

 function main() {

     stop_docker
-    build_docker_images
     start_tgi_service

+    # test rag agent
+    start_ragagent_agent_service
+    echo "=============Testing RAG Agent============="
+    validate_microservice
+    stop_agent_docker
+    echo "============================================="
+
     # test react_langchain
     start_react_langchain_agent_service
     echo "=============Testing ReAct Langchain============="
     validate_microservice
+    validate_assistant_api
     stop_agent_docker
     echo "============================================="

@@ -160,16 +178,8 @@ function main() {
     #     stop_agent_docker
     #     echo "============================================="
-
-    # test rag agent
-    start_ragagent_agent_service
-    echo "=============Testing RAG Agent============="
-    validate_microservice
-    echo "============================================="
-
     stop_docker
     echo y | docker system prune 2>&1 > /dev/null
-
 }

 main
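
Note on usage: the new OpenAI-style endpoints registered in agent.py (/v1/assistants, /v1/threads, /v1/threads/{thread_id}/messages, /v1/threads/{thread_id}/runs, /v1/threads/{thread_id}/runs/cancel) can also be exercised end to end with plain curl. The sketch below is illustrative only; it assumes the agent container's port 9090 is published on host port 5042 as in the test script above, and the ids returned by the first two calls must be substituted into the later requests.

    # 1. create an assistant (response contains "id": "assistant_...")
    curl http://${ip_address}:5042/v1/assistants -X POST -H 'Content-Type: application/json' -d '{}'

    # 2. create a thread (response contains "id": "thread_...")
    curl http://${ip_address}:5042/v1/threads -X POST -H 'Content-Type: application/json' -d '{}'

    # 3. add a user message to the thread
    curl http://${ip_address}:5042/v1/threads/<thread_id>/messages -X POST -H 'Content-Type: application/json' \
         -d '{"role": "user", "content": "What is Intel OPEA project?"}'

    # 4. start a run; the agent output is streamed back as server-sent events
    curl -N http://${ip_address}:5042/v1/threads/<thread_id>/runs -X POST -H 'Content-Type: application/json' \
         -d '{"assistant_id": "<assistant_id>"}'

    # 5. optionally cancel a run that is still in progress
    curl http://${ip_address}:5042/v1/threads/<thread_id>/runs/cancel -X POST -H 'Content-Type: application/json'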