From 8d6faf56657070137a16669bc2420a08a5ab7f24 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sat, 11 Nov 2023 09:01:17 +0000 Subject: [PATCH] Make it easier to subclass runnable binding with custom init args (#13189) --- libs/langchain/langchain/runnables/hub.py | 4 +- .../langchain/runnables/openai_functions.py | 5 +- .../langchain/schema/runnable/base.py | 191 +++++++++--------- .../langchain/schema/runnable/fallbacks.py | 5 - .../langchain/schema/runnable/retry.py | 4 +- 5 files changed, 99 insertions(+), 110 deletions(-) diff --git a/libs/langchain/langchain/runnables/hub.py b/libs/langchain/langchain/runnables/hub.py index 8fc96e3bacff0..64dbe2f61805d 100644 --- a/libs/langchain/langchain/runnables/hub.py +++ b/libs/langchain/langchain/runnables/hub.py @@ -1,9 +1,9 @@ from typing import Any, Optional -from langchain.schema.runnable.base import Input, Output, RunnableBinding +from langchain.schema.runnable.base import Input, Output, RunnableBindingBase -class HubRunnable(RunnableBinding[Input, Output]): +class HubRunnable(RunnableBindingBase[Input, Output]): """ An instance of a runnable stored in the LangChain Hub. """ diff --git a/libs/langchain/langchain/runnables/openai_functions.py b/libs/langchain/langchain/runnables/openai_functions.py index 1ee9f44b971ef..cdabef48fc0bd 100644 --- a/libs/langchain/langchain/runnables/openai_functions.py +++ b/libs/langchain/langchain/runnables/openai_functions.py @@ -5,7 +5,8 @@ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser from langchain.schema.messages import BaseMessage -from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding +from langchain.schema.runnable import RouterRunnable, Runnable +from langchain.schema.runnable.base import RunnableBindingBase class OpenAIFunction(TypedDict): @@ -19,7 +20,7 @@ class OpenAIFunction(TypedDict): """The parameters to the function.""" -class OpenAIFunctionsRouter(RunnableBinding[BaseMessage, Any]): +class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]): """A runnable that routes to the selected function.""" functions: Optional[List[OpenAIFunction]] diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 4308a20253486..7cf6d2af904a4 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -2581,11 +2581,6 @@ def get_output_schema( def config_specs(self) -> Sequence[ConfigurableFieldSpec]: return self.bound.config_specs - def config_schema( - self, *, include: Optional[Sequence[str]] = None - ) -> Type[BaseModel]: - return self.bound.config_schema(include=include) - @classmethod def is_lc_serializable(cls) -> bool: return True @@ -2659,7 +2654,7 @@ async def ainvoke( return await self._acall_with_config(self._ainvoke, input, config, **kwargs) -class RunnableBinding(RunnableSerializable[Input, Output]): +class RunnableBindingBase(RunnableSerializable[Input, Output]): """ A runnable that delegates calls to another runnable with a set of kwargs. """ @@ -2749,11 +2744,6 @@ def get_output_schema( def config_specs(self) -> Sequence[ConfigurableFieldSpec]: return self.bound.config_specs - def config_schema( - self, *, include: Optional[Sequence[str]] = None - ) -> Type[BaseModel]: - return self.bound.config_schema(include=include) - @classmethod def is_lc_serializable(cls) -> bool: return True @@ -2762,93 +2752,6 @@ def is_lc_serializable(cls) -> bool: def get_lc_namespace(cls) -> List[str]: return cls.__module__.split(".")[:-1] - def bind(self, **kwargs: Any) -> Runnable[Input, Output]: - return self.__class__( - bound=self.bound, - config=self.config, - kwargs={**self.kwargs, **kwargs}, - custom_input_type=self.custom_input_type, - custom_output_type=self.custom_output_type, - ) - - def with_config( - self, - config: Optional[RunnableConfig] = None, - # Sadly Unpack is not well supported by mypy so this will have to be untyped - **kwargs: Any, - ) -> Runnable[Input, Output]: - return self.__class__( - bound=self.bound, - kwargs=self.kwargs, - config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}), - custom_input_type=self.custom_input_type, - custom_output_type=self.custom_output_type, - ) - - def with_listeners( - self, - *, - on_start: Optional[Listener] = None, - on_end: Optional[Listener] = None, - on_error: Optional[Listener] = None, - ) -> Runnable[Input, Output]: - """ - Bind lifecycle listeners to a Runnable, returning a new Runnable. - - on_start: Called before the runnable starts running, with the Run object. - on_end: Called after the runnable finishes running, with the Run object. - on_error: Called if the runnable throws an error, with the Run object. - - The Run object contains information about the run, including its id, - type, input, output, error, start_time, end_time, and any tags or metadata - added to the run. - """ - from langchain.callbacks.tracers.root_listeners import RootListenersTracer - - return self.__class__( - bound=self.bound, - kwargs=self.kwargs, - config=self.config, - config_factories=[ - lambda config: { - "callbacks": [ - RootListenersTracer( - config=config, - on_start=on_start, - on_end=on_end, - on_error=on_error, - ) - ], - } - ], - custom_input_type=self.custom_input_type, - custom_output_type=self.custom_output_type, - ) - - def with_types( - self, - input_type: Optional[Union[Type[Input], BaseModel]] = None, - output_type: Optional[Union[Type[Output], BaseModel]] = None, - ) -> Runnable[Input, Output]: - return self.__class__( - bound=self.bound, - kwargs=self.kwargs, - config=self.config, - custom_input_type=input_type - if input_type is not None - else self.custom_input_type, - custom_output_type=output_type - if output_type is not None - else self.custom_output_type, - ) - - def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]: - return self.__class__( - bound=self.bound.with_retry(**kwargs), - kwargs=self.kwargs, - config=self.config, - ) - def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: config = merge_configs(self.config, *configs) return merge_configs(config, *(f(config) for f in self.config_factories)) @@ -2972,7 +2875,97 @@ async def atransform( yield item -RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig) +RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig) + + +class RunnableBinding(RunnableBindingBase[Input, Output]): + def bind(self, **kwargs: Any) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound, + config=self.config, + kwargs={**self.kwargs, **kwargs}, + custom_input_type=self.custom_input_type, + custom_output_type=self.custom_output_type, + ) + + def with_config( + self, + config: Optional[RunnableConfig] = None, + # Sadly Unpack is not well supported by mypy so this will have to be untyped + **kwargs: Any, + ) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound, + kwargs=self.kwargs, + config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}), + custom_input_type=self.custom_input_type, + custom_output_type=self.custom_output_type, + ) + + def with_listeners( + self, + *, + on_start: Optional[Listener] = None, + on_end: Optional[Listener] = None, + on_error: Optional[Listener] = None, + ) -> Runnable[Input, Output]: + """ + Bind lifecycle listeners to a Runnable, returning a new Runnable. + + on_start: Called before the runnable starts running, with the Run object. + on_end: Called after the runnable finishes running, with the Run object. + on_error: Called if the runnable throws an error, with the Run object. + + The Run object contains information about the run, including its id, + type, input, output, error, start_time, end_time, and any tags or metadata + added to the run. + """ + from langchain.callbacks.tracers.root_listeners import RootListenersTracer + + return self.__class__( + bound=self.bound, + kwargs=self.kwargs, + config=self.config, + config_factories=[ + lambda config: { + "callbacks": [ + RootListenersTracer( + config=config, + on_start=on_start, + on_end=on_end, + on_error=on_error, + ) + ], + } + ], + custom_input_type=self.custom_input_type, + custom_output_type=self.custom_output_type, + ) + + def with_types( + self, + input_type: Optional[Union[Type[Input], BaseModel]] = None, + output_type: Optional[Union[Type[Output], BaseModel]] = None, + ) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound, + kwargs=self.kwargs, + config=self.config, + custom_input_type=input_type + if input_type is not None + else self.custom_input_type, + custom_output_type=output_type + if output_type is not None + else self.custom_output_type, + ) + + def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound.with_retry(**kwargs), + kwargs=self.kwargs, + config=self.config, + ) + RunnableLike = Union[ Runnable[Input, Output], diff --git a/libs/langchain/langchain/schema/runnable/fallbacks.py b/libs/langchain/langchain/schema/runnable/fallbacks.py index 7d2c834c982bf..cd8e754c07b6b 100644 --- a/libs/langchain/langchain/schema/runnable/fallbacks.py +++ b/libs/langchain/langchain/schema/runnable/fallbacks.py @@ -119,11 +119,6 @@ def config_specs(self) -> Sequence[ConfigurableFieldSpec]: for spec in step.config_specs ) - def config_schema( - self, *, include: Optional[Sequence[str]] = None - ) -> Type[BaseModel]: - return self.runnable.config_schema(include=include) - @classmethod def is_lc_serializable(cls) -> bool: return True diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index 504857a4c663e..99b665bf1eec0 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -21,7 +21,7 @@ wait_exponential_jitter, ) -from langchain.schema.runnable.base import Input, Output, RunnableBinding +from langchain.schema.runnable.base import Input, Output, RunnableBindingBase from langchain.schema.runnable.config import RunnableConfig, patch_config if TYPE_CHECKING: @@ -34,7 +34,7 @@ U = TypeVar("U") -class RunnableRetry(RunnableBinding[Input, Output]): +class RunnableRetry(RunnableBindingBase[Input, Output]): """Retry a Runnable if it fails. A RunnableRetry helps can be used to add retry logic to any object