Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Dec 10, 2024
1 parent f1e5801 commit 75ccb53
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 74 deletions.
114 changes: 67 additions & 47 deletions weave/integrations/anthropic/anthropic_sdk.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
import dataclasses
import importlib
from collections.abc import AsyncIterator, Iterator
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import weave
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import _IteratorWrapper, add_accumulator
from weave.trace.patcher import MultiPatcher, SymbolPatcher

if TYPE_CHECKING:
from anthropic.lib.streaming import MessageStream
from anthropic.types import Message, MessageStreamEvent

_anthropic_patcher: Optional[MultiPatcher] = None


def anthropic_accumulator(
acc: Optional["Message"],
Expand Down Expand Up @@ -73,13 +71,11 @@ def should_use_accumulator(inputs: dict) -> bool:
return isinstance(inputs, dict) and bool(inputs.get("stream"))


def create_wrapper_sync(
name: str,
) -> Callable[[Callable], Callable]:
def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
"We need to do this so we can check if `stream` is used"
op = weave.op()(fn)
op.name = name # type: ignore
op_kwargs = dataclasses.asdict(settings)
op = weave.op(fn, **op_kwargs)
return add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: anthropic_accumulator,
Expand All @@ -92,9 +88,7 @@ def wrapper(fn: Callable) -> Callable:
# Surprisingly, the async `client.chat.completions.create` does not pass
# `inspect.iscoroutinefunction`, so we can't dispatch on it and must write
# it manually here...
def create_wrapper_async(
name: str,
) -> Callable[[Callable], Callable]:
def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
def _fn_wrapper(fn: Callable) -> Callable:
@wraps(fn)
Expand All @@ -104,8 +98,8 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
return _async_wrapper

"We need to do this so we can check if `stream` is used"
op = weave.op()(_fn_wrapper(fn))
op.name = name # type: ignore
op_kwargs = dataclasses.asdict(settings)
op = weave.op(_fn_wrapper(fn), **op_kwargs)
return add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: anthropic_accumulator,
Expand Down Expand Up @@ -171,12 +165,10 @@ def text_stream(self) -> Union[Iterator[str], AsyncIterator[str]]:
return self.__stream_text__()


def create_stream_wrapper(
name: str,
) -> Callable[[Callable], Callable]:
def create_stream_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op.name = name # type: ignore
op_kwargs = dataclasses.asdict(settings)
op = weave.op(fn, **op_kwargs)
return add_accumulator(
op, # type: ignore
make_accumulator=lambda _: anthropic_stream_accumulator,
Expand All @@ -187,28 +179,56 @@ def wrapper(fn: Callable) -> Callable:
return wrapper


anthropic_patcher = MultiPatcher(
[
# Patch the sync messages.create method for all messages.create methods
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"Messages.create",
create_wrapper_sync(name="anthropic.Messages.create"),
),
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"AsyncMessages.create",
create_wrapper_async(name="anthropic.AsyncMessages.create"),
),
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"Messages.stream",
create_stream_wrapper(name="anthropic.Messages.stream"),
),
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"AsyncMessages.stream",
create_stream_wrapper(name="anthropic.AsyncMessages.stream"),
),
]
)
def get_anthropic_patcher(
settings: Optional[IntegrationSettings] = None,
) -> MultiPatcher:
global _anthropic_patcher

if _anthropic_patcher is not None:
return _anthropic_patcher

if settings is None:
settings = IntegrationSettings()

base = settings.op_settings

completions_create_settings = dataclasses.replace(
base,
name=base.name or "anthropic.Messages.create",
)
async_completions_create_settings = dataclasses.replace(
base,
name=base.name or "anthropic.AsyncMessages.create",
)
stream_settings = dataclasses.replace(
base,
name=base.name or "anthropic.Messages.stream",
)

_anthropic_patcher = MultiPatcher(
[
# Patch the sync messages.create method for all messages.create methods
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"Messages.create",
create_wrapper_sync(completions_create_settings),
),
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"AsyncMessages.create",
create_wrapper_async(async_completions_create_settings),
),
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"Messages.stream",
create_stream_wrapper(stream_settings),
),
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"AsyncMessages.stream",
create_stream_wrapper(stream_settings),
),
]
)

return _anthropic_patcher
51 changes: 28 additions & 23 deletions weave/integrations/openai/openai_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,28 +406,33 @@ def get_openai_patcher(settings: Optional[IntegrationSettings] = None) -> MultiP
name=base.name or "openai.beta.chat.completions.parse",
)

symbol_patchers = [
SymbolPatcher(
lambda: importlib.import_module("openai.resources.chat.completions"),
"Completions.create",
create_wrapper_sync(settings=completions_create_settings),
),
SymbolPatcher(
lambda: importlib.import_module("openai.resources.chat.completions"),
"AsyncCompletions.create",
create_wrapper_async(settings=async_completions_create_settings),
),
SymbolPatcher(
lambda: importlib.import_module("openai.resources.beta.chat.completions"),
"Completions.parse",
create_wrapper_sync(settings=completions_parse_settings),
),
SymbolPatcher(
lambda: importlib.import_module("openai.resources.beta.chat.completions"),
"AsyncCompletions.parse",
create_wrapper_async(settings=async_completions_parse_settings),
),
]
_openai_patcher = MultiPatcher(symbol_patchers) # type: ignore
_openai_patcher = MultiPatcher(
[
SymbolPatcher(
lambda: importlib.import_module("openai.resources.chat.completions"),
"Completions.create",
create_wrapper_sync(settings=completions_create_settings),
),
SymbolPatcher(
lambda: importlib.import_module("openai.resources.chat.completions"),
"AsyncCompletions.create",
create_wrapper_async(settings=async_completions_create_settings),
),
SymbolPatcher(
lambda: importlib.import_module(
"openai.resources.beta.chat.completions"
),
"Completions.parse",
create_wrapper_sync(settings=completions_parse_settings),
),
SymbolPatcher(
lambda: importlib.import_module(
"openai.resources.beta.chat.completions"
),
"AsyncCompletions.parse",
create_wrapper_async(settings=async_completions_parse_settings),
),
]
)

return _openai_patcher
8 changes: 4 additions & 4 deletions weave/trace/autopatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def autopatch(settings: AutopatchSettings | None = None) -> None:
from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher
from weave.integrations.anthropic.anthropic_sdk import get_anthropic_patcher
from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher
from weave.integrations.cohere.cohere_sdk import cohere_patcher
from weave.integrations.dspy.dspy_sdk import dspy_patcher
Expand All @@ -38,7 +38,7 @@ def autopatch(settings: AutopatchSettings | None = None) -> None:
litellm_patcher.attempt_patch()
llamaindex_patcher.attempt_patch()
langchain_patcher.attempt_patch()
anthropic_patcher.attempt_patch()
get_anthropic_patcher(settings.anthropic).attempt_patch()
groq_patcher.attempt_patch()
instructor_patcher.attempt_patch()
dspy_patcher.attempt_patch()
Expand All @@ -50,7 +50,7 @@ def autopatch(settings: AutopatchSettings | None = None) -> None:


def reset_autopatch() -> None:
from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher
from weave.integrations.anthropic.anthropic_sdk import get_anthropic_patcher
from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher
from weave.integrations.cohere.cohere_sdk import cohere_patcher
from weave.integrations.dspy.dspy_sdk import dspy_patcher
Expand All @@ -72,7 +72,7 @@ def reset_autopatch() -> None:
litellm_patcher.undo_patch()
llamaindex_patcher.undo_patch()
langchain_patcher.undo_patch()
anthropic_patcher.undo_patch()
get_anthropic_patcher().undo_patch()
groq_patcher.undo_patch()
instructor_patcher.undo_patch()
dspy_patcher.undo_patch()
Expand Down

0 comments on commit 75ccb53

Please sign in to comment.