Skip to content

Commit

Permalink
Enable customizing the output parser of OpenAIFunctionsAgent (langc…
Browse files Browse the repository at this point in the history
…hain-ai#15827)

- **Description:** This PR defines the output parser of
OpenAIFunctionsAgent as an attribute, enabling customization and
subclassing of the parser logic.
- **Issue:** Subclassing is currently impossible as the
`OpenAIFunctionsAgentOutputParser` class is hard coded into the `plan`
and `aplan` methods
  - **Dependencies:** None

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->

---------

Co-authored-by: Harrison Chase <[email protected]>
  • Loading branch information
treisfeld and hwchase17 authored Jan 12, 2024
1 parent 560bb49 commit eb9b334
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions libs/langchain/langchain/agents/openai_functions_agent/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, List, Optional, Sequence, Tuple, Type, Union

from langchain_community.tools.convert_to_openai import format_tool_to_openai_function
from langchain_core._api import deprecated
Expand Down Expand Up @@ -47,6 +47,9 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
llm: BaseLanguageModel
tools: Sequence[BaseTool]
prompt: BasePromptTemplate
output_parser: Type[
OpenAIFunctionsAgentOutputParser
] = OpenAIFunctionsAgentOutputParser

def get_allowed_tools(self) -> List[str]:
"""Get allowed tools."""
Expand Down Expand Up @@ -105,9 +108,7 @@ def plan(
messages,
callbacks=callbacks,
)
agent_decision = OpenAIFunctionsAgentOutputParser._parse_ai_message(
predicted_message
)
agent_decision = self.output_parser._parse_ai_message(predicted_message)
return agent_decision

async def aplan(
Expand Down Expand Up @@ -136,9 +137,7 @@ async def aplan(
predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = OpenAIFunctionsAgentOutputParser._parse_ai_message(
predicted_message
)
agent_decision = self.output_parser._parse_ai_message(predicted_message)
return agent_decision

def return_stopped_response(
Expand Down

0 comments on commit eb9b334

Please sign in to comment.