Skip to content

Commit

Permalink
feat(weave): Support op configuration for autopatched functions for r…
Browse files Browse the repository at this point in the history
…emaining integrations (#3216)
  • Loading branch information
andrewtruong authored Dec 18, 2024
1 parent af7e421 commit adb740b
Show file tree
Hide file tree
Showing 19 changed files with 1,206 additions and 685 deletions.
6 changes: 3 additions & 3 deletions tests/integrations/litellm/litellm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from packaging.version import parse as version_parse

import weave
from weave.integrations.litellm.litellm import litellm_patcher
from weave.integrations.litellm.litellm import get_litellm_patcher

# This PR:
# https://github.com/BerriAI/litellm/commit/fe2aa706e8ff4edbcd109897e5da6b83ef6ad693
Expand Down Expand Up @@ -38,9 +38,9 @@ def patch_litellm(request: Any) -> Generator[None, None, None]:
yield
return

litellm_patcher.attempt_patch()
get_litellm_patcher().attempt_patch()
yield
litellm_patcher.undo_patch()
get_litellm_patcher().undo_patch()


@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode
Expand Down
135 changes: 79 additions & 56 deletions weave/integrations/anthropic/anthropic_sdk.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
from __future__ import annotations

import importlib
from collections.abc import AsyncIterator, Iterator
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Union,
)
from typing import TYPE_CHECKING, Any, Callable

import weave
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import _IteratorWrapper, add_accumulator
from weave.trace.patcher import MultiPatcher, SymbolPatcher
from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher

if TYPE_CHECKING:
from anthropic.lib.streaming import MessageStream
from anthropic.types import Message, MessageStreamEvent

_anthropic_patcher: MultiPatcher | None = None


def anthropic_accumulator(
acc: Optional["Message"],
value: "MessageStreamEvent",
) -> "Message":
acc: Message | None,
value: MessageStreamEvent,
) -> Message:
from anthropic.types import (
ContentBlockDeltaEvent,
Message,
Expand Down Expand Up @@ -73,13 +72,11 @@ def should_use_accumulator(inputs: dict) -> bool:
return isinstance(inputs, dict) and bool(inputs.get("stream"))


def create_wrapper_sync(
name: str,
) -> Callable[[Callable], Callable]:
def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
"We need to do this so we can check if `stream` is used"
op = weave.op()(fn)
op.name = name # type: ignore
op_kwargs = settings.model_dump()
op = weave.op(fn, **op_kwargs)
return add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: anthropic_accumulator,
Expand All @@ -92,9 +89,7 @@ def wrapper(fn: Callable) -> Callable:
# Surprisingly, the async `client.chat.completions.create` does not pass
# `inspect.iscoroutinefunction`, so we can't dispatch on it and must write
# it manually here...
def create_wrapper_async(
name: str,
) -> Callable[[Callable], Callable]:
def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
def _fn_wrapper(fn: Callable) -> Callable:
@wraps(fn)
Expand All @@ -104,8 +99,8 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
return _async_wrapper

"We need to do this so we can check if `stream` is used"
op = weave.op()(_fn_wrapper(fn))
op.name = name # type: ignore
op_kwargs = settings.model_dump()
op = weave.op(_fn_wrapper(fn), **op_kwargs)
return add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: anthropic_accumulator,
Expand All @@ -123,9 +118,9 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:


def anthropic_stream_accumulator(
acc: Optional["Message"],
value: "MessageStream",
) -> "Message":
acc: Message | None,
value: MessageStream,
) -> Message:
from anthropic.lib.streaming._types import MessageStopEvent

if acc is None:
Expand All @@ -150,7 +145,7 @@ def __getattr__(self, name: str) -> Any:
return object.__getattribute__(self, name)
return getattr(self._iterator_or_ctx_manager, name)

def __stream_text__(self) -> Union[Iterator[str], AsyncIterator[str]]:
def __stream_text__(self) -> Iterator[str] | AsyncIterator[str]:
if isinstance(self._iterator_or_ctx_manager, AsyncIterator):
return self.__async_stream_text__()
else:
Expand All @@ -167,16 +162,14 @@ async def __async_stream_text__(self) -> AsyncIterator[str]: # type: ignore
yield chunk.delta.text # type: ignore

@property
def text_stream(self) -> Union[Iterator[str], AsyncIterator[str]]:
def text_stream(self) -> Iterator[str] | AsyncIterator[str]:
return self.__stream_text__()


def create_stream_wrapper(
name: str,
) -> Callable[[Callable], Callable]:
def create_stream_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op.name = name # type: ignore
op_kwargs = settings.model_dump()
op = weave.op(fn, **op_kwargs)
return add_accumulator(
op, # type: ignore
make_accumulator=lambda _: anthropic_stream_accumulator,
Expand All @@ -187,28 +180,58 @@ def wrapper(fn: Callable) -> Callable:
return wrapper


anthropic_patcher = MultiPatcher(
[
# Patch the sync messages.create method for all messages.create methods
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"Messages.create",
create_wrapper_sync(name="anthropic.Messages.create"),
),
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"AsyncMessages.create",
create_wrapper_async(name="anthropic.AsyncMessages.create"),
),
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"Messages.stream",
create_stream_wrapper(name="anthropic.Messages.stream"),
),
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"AsyncMessages.stream",
create_stream_wrapper(name="anthropic.AsyncMessages.stream"),
),
]
)
def get_anthropic_patcher(
settings: IntegrationSettings | None = None,
) -> MultiPatcher | NoOpPatcher:
if settings is None:
settings = IntegrationSettings()

if not settings.enabled:
return NoOpPatcher()

global _anthropic_patcher
if _anthropic_patcher is not None:
return _anthropic_patcher

base = settings.op_settings

messages_create_settings = base.model_copy(
update={"name": base.name or "anthropic.Messages.create"}
)
async_messages_create_settings = base.model_copy(
update={"name": base.name or "anthropic.AsyncMessages.create"}
)
stream_settings = base.model_copy(
update={"name": base.name or "anthropic.Messages.stream"}
)
async_stream_settings = base.model_copy(
update={"name": base.name or "anthropic.AsyncMessages.stream"}
)

_anthropic_patcher = MultiPatcher(
[
# Patch the sync messages.create method for all messages.create methods
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"Messages.create",
create_wrapper_sync(messages_create_settings),
),
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"AsyncMessages.create",
create_wrapper_async(async_messages_create_settings),
),
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"Messages.stream",
create_stream_wrapper(stream_settings),
),
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"AsyncMessages.stream",
create_stream_wrapper(async_stream_settings),
),
]
)

return _anthropic_patcher
72 changes: 47 additions & 25 deletions weave/integrations/cerebras/cerebras_sdk.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
from __future__ import annotations

import importlib
from functools import wraps
from typing import Any, Callable

import weave
from weave.trace.patcher import MultiPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher

_cerebras_patcher: MultiPatcher | None = None


def create_wrapper_sync(
name: str,
) -> Callable[[Callable], Callable]:
def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op.name = name # type: ignore
op_kwargs = settings.model_dump()
op = weave.op(fn, **op_kwargs)
return op

return wrapper


def create_wrapper_async(
name: str,
) -> Callable[[Callable], Callable]:
def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
def _fn_wrapper(fn: Callable) -> Callable:
@wraps(fn)
Expand All @@ -28,24 +29,45 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:

return _async_wrapper

op = weave.op()(_fn_wrapper(fn))
op.name = name # type: ignore
op_kwargs = settings.model_dump()
op = weave.op(_fn_wrapper(fn), **op_kwargs)
return op

return wrapper


cerebras_patcher = MultiPatcher(
[
SymbolPatcher(
lambda: importlib.import_module("cerebras.cloud.sdk.resources.chat"),
"CompletionsResource.create",
create_wrapper_sync(name="cerebras.chat.completions.create"),
),
SymbolPatcher(
lambda: importlib.import_module("cerebras.cloud.sdk.resources.chat"),
"AsyncCompletionsResource.create",
create_wrapper_async(name="cerebras.chat.completions.create"),
),
]
)
def get_cerebras_patcher(
settings: IntegrationSettings | None = None,
) -> MultiPatcher | NoOpPatcher:
if settings is None:
settings = IntegrationSettings()

if not settings.enabled:
return NoOpPatcher()

global _cerebras_patcher
if _cerebras_patcher is not None:
return _cerebras_patcher

base = settings.op_settings

create_settings = base.model_copy(
update={"name": base.name or "cerebras.chat.completions.create"}
)

_cerebras_patcher = MultiPatcher(
[
SymbolPatcher(
lambda: importlib.import_module("cerebras.cloud.sdk.resources.chat"),
"CompletionsResource.create",
create_wrapper_sync(create_settings),
),
SymbolPatcher(
lambda: importlib.import_module("cerebras.cloud.sdk.resources.chat"),
"AsyncCompletionsResource.create",
create_wrapper_async(create_settings),
),
]
)

return _cerebras_patcher
2 changes: 1 addition & 1 deletion weave/integrations/cohere/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from weave.integrations.cohere.cohere_sdk import cohere_patcher as cohere_patcher
from weave.integrations.cohere.cohere_sdk import get_cohere_patcher # noqa: F401
Loading

0 comments on commit adb740b

Please sign in to comment.