Skip to content

Commit

Permalink
Add Assistant API for agent (opea-project#490)
Browse files Browse the repository at this point in the history
* Add assistants api support

Signed-off-by: Chendi.Xue <[email protected]>

* update UT to match new port definition

Signed-off-by: Chendi.Xue <[email protected]>

* rebase

Signed-off-by: Chendi.Xue <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix gramma in Test

Signed-off-by: Chendi.Xue <[email protected]>

---------

Signed-off-by: Chendi.Xue <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
xuechendi and pre-commit-ci[bot] authored Aug 20, 2024
1 parent 5dedd04 commit f3a8935
Show file tree
Hide file tree
Showing 13 changed files with 568 additions and 59 deletions.
161 changes: 153 additions & 8 deletions comps/agent/langchain/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import os
import pathlib
import sys
from datetime import datetime
from typing import Union

from fastapi.responses import StreamingResponse

Expand All @@ -14,7 +16,21 @@

from comps import CustomLogger, GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice
from comps.agent.langchain.src.agent import instantiate_agent
from comps.agent.langchain.src.global_var import assistants_global_kv, threads_global_kv
from comps.agent.langchain.src.thread import instantiate_thread_memory, thread_completion_callback
from comps.agent.langchain.src.utils import get_args
from comps.cores.proto.api_protocol import (
AssistantsObject,
ChatCompletionRequest,
CreateAssistantsRequest,
CreateMessagesRequest,
CreateRunResponse,
CreateThreadsRequest,
MessageContent,
MessageObject,
RunObject,
ThreadObject,
)

logger = CustomLogger("comps-react-agent")
logflag = os.getenv("LOGFLAG", False)
Expand All @@ -23,14 +39,13 @@


@register_microservice(
name="opea_service@comps-react-agent",
name="opea_service@comps-chat-agent",
service_type=ServiceType.LLM,
endpoint="/v1/chat/completions",
host="0.0.0.0",
port=args.port,
input_datatype=LLMParamsDoc,
)
async def llm_generate(input: LLMParamsDoc):
async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]):
if logflag:
logger.info(input)
# 1. initialize the agent
Expand All @@ -42,19 +57,149 @@ async def llm_generate(input: LLMParamsDoc):
if logflag:
logger.info(type(agent_inst))

if isinstance(input, LLMParamsDoc):
# use query as input
input_query = input.query
else:
# openai compatible input
if isinstance(input.messages, str):
input_query = input.messages
else:
input_query = input.messages[-1]["content"]

# 2. prepare the input for the agent
if input.streaming:
print("-----------STREAMING-------------")
return StreamingResponse(agent_inst.stream_generator(input.query, config), media_type="text/event-stream")
return StreamingResponse(agent_inst.stream_generator(input_query, config), media_type="text/event-stream")

else:
# TODO: add support for non-streaming mode
print("-----------NOT STREAMING-------------")
response = await agent_inst.non_streaming_run(input.query, config)
response = await agent_inst.non_streaming_run(input_query, config)
print("-----------Response-------------")
print(response)
return GeneratedDoc(text=response, prompt=input.query)
return GeneratedDoc(text=response, prompt=input_query)


@register_microservice(
name="opea_service@comps-chat-agent",
endpoint="/v1/assistants",
host="0.0.0.0",
port=args.port,
)
def create_assistants(input: CreateAssistantsRequest):
# 1. initialize the agent
print("args: ", args)
agent_inst = instantiate_agent(args, args.strategy, with_memory=True)
agent_id = agent_inst.id
created_at = int(datetime.now().timestamp())
with assistants_global_kv as g_assistants:
g_assistants[agent_id] = (agent_inst, created_at)
print(f"Record assistant inst {agent_id} in global KV")

# get current time in string format
return AssistantsObject(
id=agent_id,
created_at=created_at,
)


@register_microservice(
name="opea_service@comps-chat-agent",
endpoint="/v1/threads",
host="0.0.0.0",
port=args.port,
)
def create_threads(input: CreateThreadsRequest):
# create a memory KV for the thread
thread_inst, thread_id = instantiate_thread_memory()
created_at = int(datetime.now().timestamp())
status = "ready"
with threads_global_kv as g_threads:
g_threads[thread_id] = (thread_inst, created_at, status)
print(f"Record thread inst {thread_id} in global KV")

return ThreadObject(
id=thread_id,
created_at=created_at,
)


@register_microservice(
name="opea_service@comps-chat-agent",
endpoint="/v1/threads/{thread_id}/messages",
host="0.0.0.0",
port=args.port,
)
def create_messages(thread_id, input: CreateMessagesRequest):
with threads_global_kv as g_threads:
thread_inst, _, _ = g_threads[thread_id]

# create a memory KV for the message
role = input.role
if isinstance(input.content, str):
query = input.content
else:
query = input.content[-1]["text"]
msg_id, created_at = thread_inst.add_query(query)

structured_content = MessageContent(text=query)
return MessageObject(
id=msg_id,
created_at=created_at,
thread_id=thread_id,
role=role,
content=[structured_content],
)


@register_microservice(
name="opea_service@comps-chat-agent",
endpoint="/v1/threads/{thread_id}/runs",
host="0.0.0.0",
port=args.port,
)
def create_run(thread_id, input: CreateRunResponse):
with threads_global_kv as g_threads:
thread_inst, _, status = g_threads[thread_id]

if status == "running":
return "[error] Thread is already running, need to cancel the current run or wait for it to finish"

agent_id = input.assistant_id
with assistants_global_kv as g_assistants:
agent_inst, _ = g_assistants[agent_id]

config = {"recursion_limit": args.recursion_limit}
input_query = thread_inst.get_query()
try:
return StreamingResponse(
thread_completion_callback(agent_inst.stream_generator(input_query, config, thread_id), thread_id),
media_type="text/event-stream",
)
except Exception as e:
with threads_global_kv as g_threads:
thread_inst, created_at, status = g_threads[thread_id]
g_threads[thread_id] = (thread_inst, created_at, "ready")
return f"An error occurred: {e}. This thread is now set as ready"


@register_microservice(
name="opea_service@comps-chat-agent",
endpoint="/v1/threads/{thread_id}/runs/cancel",
host="0.0.0.0",
port=args.port,
)
def cancel_run(thread_id):
with threads_global_kv as g_threads:
thread_inst, created_at, status = g_threads[thread_id]
if status == "ready":
return "Thread is not running, no need to cancel"
elif status == "try_cancel":
return "cancel request is submitted"
else:
g_threads[thread_id] = (thread_inst, created_at, "try_cancel")
return "submit cancel request"


if __name__ == "__main__":
opea_microservices["opea_service@comps-react-agent"].start()
opea_microservices["opea_service@comps-chat-agent"].start()
7 changes: 4 additions & 3 deletions comps/agent/langchain/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ docarray[full]
duckduckgo-search
fastapi
huggingface_hub==0.24.0
langchain #==0.1.12
langchain==0.2.9
langchain-huggingface
langchain-openai
langchain_community
langchainhub
langchain_community==0.2.7
langchainhub==0.1.20
langgraph
langsmith
numpy

# used by cloud native
Expand Down
10 changes: 5 additions & 5 deletions comps/agent/langchain/src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
# SPDX-License-Identifier: Apache-2.0


def instantiate_agent(args, strategy="react_langchain"):
def instantiate_agent(args, strategy="react_langchain", with_memory=False):
if strategy == "react_langchain":
from .strategy.react import ReActAgentwithLangchain

return ReActAgentwithLangchain(args)
return ReActAgentwithLangchain(args, with_memory)
elif strategy == "react_langgraph":
from .strategy.react import ReActAgentwithLanggraph

return ReActAgentwithLanggraph(args)
return ReActAgentwithLanggraph(args, with_memory)
elif strategy == "plan_execute":
from .strategy.planexec import PlanExecuteAgentWithLangGraph

return PlanExecuteAgentWithLangGraph(args)
return PlanExecuteAgentWithLangGraph(args, with_memory)

elif strategy == "rag_agent":
from .strategy.ragagent import RAGAgent

return RAGAgent(args)
return RAGAgent(args, with_memory)
else:
raise ValueError(f"Agent strategy: {strategy} not supported!")
21 changes: 21 additions & 0 deletions comps/agent/langchain/src/global_var.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import threading


class ThreadSafeDict(dict):
def __init__(self, *p_arg, **n_arg):
dict.__init__(self, *p_arg, **n_arg)
self._lock = threading.Lock()

def __enter__(self):
self._lock.acquire()
return self

def __exit__(self, type, value, traceback):
self._lock.release()


assistants_global_kv = ThreadSafeDict()
threads_global_kv = ThreadSafeDict()
4 changes: 4 additions & 0 deletions comps/agent/langchain/src/strategy/base_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from uuid import uuid4

from ..tools import get_tools_descriptions
from ..utils import setup_llm

Expand All @@ -10,6 +12,8 @@ def __init__(self, args) -> None:
self.llm_endpoint = setup_llm(args)
self.tools_descriptions = get_tools_descriptions(args.tools)
self.app = None
self.memory = None
self.id = f"assistant_{self.__class__.__name__}_{uuid4()}"
print(self.tools_descriptions)

def compile(self):
Expand Down
32 changes: 18 additions & 14 deletions comps/agent/langchain/src/strategy/planexec/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.json import parse_partial_json
from langchain_huggingface import ChatHuggingFace
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages

from ...global_var import threads_global_kv
from ...utils import has_multi_tool_inputs, tool_renderer
from ..base_agent import BaseAgent
from .prompt import (
Expand Down Expand Up @@ -221,7 +223,7 @@ def __call__(self, state):


class PlanExecuteAgentWithLangGraph(BaseAgent):
def __init__(self, args):
def __init__(self, args, with_memory=False):
super().__init__(args)

# Define Node
Expand All @@ -231,37 +233,39 @@ def __init__(self, args):
execute_step = Executor(self.llm_endpoint, args.model, self.tools_descriptions)
make_answer = AnswerMaker(self.llm_endpoint, args.model)

# answer_checker = FinalAnswerChecker(self.llm_endpoint, args.model)
# replan_step = Replanner(self.llm_endpoint, args.model, answer_checker)

# Define Graph
workflow = StateGraph(PlanExecute)
workflow.add_node("planner", plan_step)
workflow.add_node("plan_executor", execute_step)
workflow.add_node("answer_maker", make_answer)
# workflow.add_node("replan", replan_step)

# Define edges
workflow.add_edge(START, "planner")
workflow.add_edge("planner", "plan_executor")
workflow.add_edge("plan_executor", "answer_maker")
workflow.add_edge("answer_maker", END)
# workflow.add_conditional_edges(
# "answer_maker",
# answer_checker,
# {END: END, "replan": "replan"},
# )
# workflow.add_edge("replan", "plan_executor")

# Finally, we compile it!
self.app = workflow.compile()
if with_memory:
self.app = workflow.compile(checkpointer=MemorySaver())
else:
self.app = workflow.compile()

def prepare_initial_state(self, query):
return {"messages": [("user", query)]}

async def stream_generator(self, query, config):
async def stream_generator(self, query, config, thread_id=None):
initial_state = self.prepare_initial_state(query)
if thread_id is not None:
config["configurable"] = {"thread_id": thread_id}
async for event in self.app.astream(initial_state, config=config):
if thread_id is not None:
with threads_global_kv as g_threads:
thread_inst, created_at, status = g_threads[thread_id]
if status == "try_cancel":
yield "[thread_completion_callback] signal to cancel! Changed status to ready"
print("[thread_completion_callback] signal to cancel! Changed status to ready")
g_threads[thread_id] = (thread_inst, created_at, "ready")
break
for node_name, node_state in event.items():
yield f"--- CALL {node_name} ---\n"
for k, v in node_state.items():
Expand Down
8 changes: 6 additions & 2 deletions comps/agent/langchain/src/strategy/ragagent/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
Expand Down Expand Up @@ -154,7 +155,7 @@ def __call__(self, state):


class RAGAgent(BaseAgent):
def __init__(self, args):
def __init__(self, args, with_memory=False):
super().__init__(args)

# Define Nodes
Expand Down Expand Up @@ -195,7 +196,10 @@ def __init__(self, args):
)
workflow.add_edge("generate", END)

self.app = workflow.compile()
if with_memory:
self.app = workflow.compile(checkpointer=MemorySaver())
else:
self.app = workflow.compile()

def should_retry(self, state):
# first check how many retry attempts have been made
Expand Down
Loading

0 comments on commit f3a8935

Please sign in to comment.