Skip to content

Commit

Permalink
Make it easier to subclass runnable binding with custom init args (la…
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos authored Nov 11, 2023
1 parent 7f1964b commit 8d6faf5
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 110 deletions.
4 changes: 2 additions & 2 deletions libs/langchain/langchain/runnables/hub.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Any, Optional

from langchain.schema.runnable.base import Input, Output, RunnableBinding
from langchain.schema.runnable.base import Input, Output, RunnableBindingBase


class HubRunnable(RunnableBinding[Input, Output]):
class HubRunnable(RunnableBindingBase[Input, Output]):
"""
An instance of a runnable stored in the LangChain Hub.
"""
Expand Down
5 changes: 3 additions & 2 deletions libs/langchain/langchain/runnables/openai_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema.messages import BaseMessage
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
from langchain.schema.runnable import RouterRunnable, Runnable
from langchain.schema.runnable.base import RunnableBindingBase


class OpenAIFunction(TypedDict):
Expand All @@ -19,7 +20,7 @@ class OpenAIFunction(TypedDict):
"""The parameters to the function."""


class OpenAIFunctionsRouter(RunnableBinding[BaseMessage, Any]):
class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]):
"""A runnable that routes to the selected function."""

functions: Optional[List[OpenAIFunction]]
Expand Down
191 changes: 92 additions & 99 deletions libs/langchain/langchain/schema/runnable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2581,11 +2581,6 @@ def get_output_schema(
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return self.bound.config_specs

def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.bound.config_schema(include=include)

@classmethod
def is_lc_serializable(cls) -> bool:
return True
Expand Down Expand Up @@ -2659,7 +2654,7 @@ async def ainvoke(
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)


class RunnableBinding(RunnableSerializable[Input, Output]):
class RunnableBindingBase(RunnableSerializable[Input, Output]):
"""
A runnable that delegates calls to another runnable with a set of kwargs.
"""
Expand Down Expand Up @@ -2749,11 +2744,6 @@ def get_output_schema(
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return self.bound.config_specs

def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.bound.config_schema(include=include)

@classmethod
def is_lc_serializable(cls) -> bool:
return True
Expand All @@ -2762,93 +2752,6 @@ def is_lc_serializable(cls) -> bool:
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]

def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
config=self.config,
kwargs={**self.kwargs, **kwargs},
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)

def with_config(
self,
config: Optional[RunnableConfig] = None,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)

def with_listeners(
self,
*,
on_start: Optional[Listener] = None,
on_end: Optional[Listener] = None,
on_error: Optional[Listener] = None,
) -> Runnable[Input, Output]:
"""
Bind lifecycle listeners to a Runnable, returning a new Runnable.
on_start: Called before the runnable starts running, with the Run object.
on_end: Called after the runnable finishes running, with the Run object.
on_error: Called if the runnable throws an error, with the Run object.
The Run object contains information about the run, including its id,
type, input, output, error, start_time, end_time, and any tags or metadata
added to the run.
"""
from langchain.callbacks.tracers.root_listeners import RootListenersTracer

return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config=self.config,
config_factories=[
lambda config: {
"callbacks": [
RootListenersTracer(
config=config,
on_start=on_start,
on_end=on_end,
on_error=on_error,
)
],
}
],
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)

def with_types(
self,
input_type: Optional[Union[Type[Input], BaseModel]] = None,
output_type: Optional[Union[Type[Output], BaseModel]] = None,
) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config=self.config,
custom_input_type=input_type
if input_type is not None
else self.custom_input_type,
custom_output_type=output_type
if output_type is not None
else self.custom_output_type,
)

def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound.with_retry(**kwargs),
kwargs=self.kwargs,
config=self.config,
)

def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = merge_configs(self.config, *configs)
return merge_configs(config, *(f(config) for f in self.config_factories))
Expand Down Expand Up @@ -2972,7 +2875,97 @@ async def atransform(
yield item


RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig)


class RunnableBinding(RunnableBindingBase[Input, Output]):
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
config=self.config,
kwargs={**self.kwargs, **kwargs},
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)

def with_config(
self,
config: Optional[RunnableConfig] = None,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)

def with_listeners(
self,
*,
on_start: Optional[Listener] = None,
on_end: Optional[Listener] = None,
on_error: Optional[Listener] = None,
) -> Runnable[Input, Output]:
"""
Bind lifecycle listeners to a Runnable, returning a new Runnable.
on_start: Called before the runnable starts running, with the Run object.
on_end: Called after the runnable finishes running, with the Run object.
on_error: Called if the runnable throws an error, with the Run object.
The Run object contains information about the run, including its id,
type, input, output, error, start_time, end_time, and any tags or metadata
added to the run.
"""
from langchain.callbacks.tracers.root_listeners import RootListenersTracer

return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config=self.config,
config_factories=[
lambda config: {
"callbacks": [
RootListenersTracer(
config=config,
on_start=on_start,
on_end=on_end,
on_error=on_error,
)
],
}
],
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)

def with_types(
self,
input_type: Optional[Union[Type[Input], BaseModel]] = None,
output_type: Optional[Union[Type[Output], BaseModel]] = None,
) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config=self.config,
custom_input_type=input_type
if input_type is not None
else self.custom_input_type,
custom_output_type=output_type
if output_type is not None
else self.custom_output_type,
)

def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound.with_retry(**kwargs),
kwargs=self.kwargs,
config=self.config,
)


RunnableLike = Union[
Runnable[Input, Output],
Expand Down
5 changes: 0 additions & 5 deletions libs/langchain/langchain/schema/runnable/fallbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,6 @@ def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
for spec in step.config_specs
)

def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.runnable.config_schema(include=include)

@classmethod
def is_lc_serializable(cls) -> bool:
return True
Expand Down
4 changes: 2 additions & 2 deletions libs/langchain/langchain/schema/runnable/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
wait_exponential_jitter,
)

from langchain.schema.runnable.base import Input, Output, RunnableBinding
from langchain.schema.runnable.base import Input, Output, RunnableBindingBase
from langchain.schema.runnable.config import RunnableConfig, patch_config

if TYPE_CHECKING:
Expand All @@ -34,7 +34,7 @@
U = TypeVar("U")


class RunnableRetry(RunnableBinding[Input, Output]):
class RunnableRetry(RunnableBindingBase[Input, Output]):
"""Retry a Runnable if it fails.
A RunnableRetry helps can be used to add retry logic to any object
Expand Down

0 comments on commit 8d6faf5

Please sign in to comment.