diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py index 9e6c6e0cfd74e..f388f5cae2683 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py @@ -1,5 +1,5 @@ """Module implements an agent that uses OpenAI's APIs function enabled API.""" -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Type, Union from langchain_community.tools.convert_to_openai import format_tool_to_openai_function from langchain_core._api import deprecated @@ -47,6 +47,9 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): llm: BaseLanguageModel tools: Sequence[BaseTool] prompt: BasePromptTemplate + output_parser: Type[ + OpenAIFunctionsAgentOutputParser + ] = OpenAIFunctionsAgentOutputParser def get_allowed_tools(self) -> List[str]: """Get allowed tools.""" @@ -105,9 +108,7 @@ def plan( messages, callbacks=callbacks, ) - agent_decision = OpenAIFunctionsAgentOutputParser._parse_ai_message( - predicted_message - ) + agent_decision = self.output_parser._parse_ai_message(predicted_message) return agent_decision async def aplan( @@ -136,9 +137,7 @@ async def aplan( predicted_message = await self.llm.apredict_messages( messages, functions=self.functions, callbacks=callbacks ) - agent_decision = OpenAIFunctionsAgentOutputParser._parse_ai_message( - predicted_message - ) + agent_decision = self.output_parser._parse_ai_message(predicted_message) return agent_decision def return_stopped_response(