diff --git a/weave/integrations/anthropic/anthropic_sdk.py b/weave/integrations/anthropic/anthropic_sdk.py index c991c644eb9..915f3a0a5f0 100644 --- a/weave/integrations/anthropic/anthropic_sdk.py +++ b/weave/integrations/anthropic/anthropic_sdk.py @@ -1,10 +1,12 @@ import importlib import typing from functools import wraps +from typing import Any, Union + +from typing_extensions import AsyncIterator, Iterator import weave -from weave.trace.op_extensions.accumulator import (_IteratorWrapper, - add_accumulator) +from weave.trace.op_extensions.accumulator import _IteratorWrapper, add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher if typing.TYPE_CHECKING: @@ -16,8 +18,13 @@ def anthropic_accumulator( acc: typing.Optional["Message"], value: "MessageStreamEvent", ) -> "Message": - from anthropic.types import (ContentBlockDeltaEvent, Message, - MessageDeltaEvent, TextBlock, Usage) + from anthropic.types import ( + ContentBlockDeltaEvent, + Message, + MessageDeltaEvent, + TextBlock, + Usage, + ) if acc is None: if hasattr(value, "message"): @@ -125,9 +132,6 @@ def anthropic_stream_accumulator( return acc -from typing import Any, Union - -from typing_extensions import AsyncIterator, Iterator class AnthropicIteratorWrapper(_IteratorWrapper): @@ -176,7 +180,7 @@ def wrapper(fn: typing.Callable) -> typing.Callable: op, # type: ignore make_accumulator=lambda _: anthropic_stream_accumulator, should_accumulate=lambda _: True, - iterator_wrapper=AnthropicIteratorWrapper, + iterator_wrapper=AnthropicIteratorWrapper, # type: ignore ) return wrapper diff --git a/weave/integrations/anthropic/anthropic_test.py b/weave/integrations/anthropic/anthropic_test.py index 453c21804d0..5030930dda2 100644 --- a/weave/integrations/anthropic/anthropic_test.py +++ b/weave/integrations/anthropic/anthropic_test.py @@ -2,11 +2,11 @@ from typing import Any import pytest +from anthropic import Anthropic, AsyncAnthropic + import weave from weave.trace_server import trace_server_interface as tsi -from anthropic import Anthropic, AsyncAnthropic - model = "claude-3-haiku-20240307" # model = "claude-3-opus-20240229" diff --git a/weave/trace/op_extensions/accumulator.py b/weave/trace/op_extensions/accumulator.py index 3d8f0c232db..7f385449a28 100644 --- a/weave/trace/op_extensions/accumulator.py +++ b/weave/trace/op_extensions/accumulator.py @@ -9,6 +9,7 @@ Generic, Iterator, Optional, + Type, TypeVar, Union, ) @@ -156,7 +157,7 @@ def add_accumulator( *, should_accumulate: Optional[Callable[[Dict], bool]] = None, on_finish_post_processor: Optional[Callable[[Any], Any]] = None, - iterator_wrapper: "_IteratorWrapper" = _IteratorWrapper, + iterator_wrapper: Type[_IteratorWrapper[Any]] = _IteratorWrapper, ) -> Op: """This is to be used internally only - specifically designed for integrations with streaming libraries. @@ -182,7 +183,6 @@ def simple_list_accumulator(acc, value): acc.append(value) return acc add_accumulator(fn, simple_list_accumulator) # returns the op with `list(range(9, -1, -1))` as output - ``` """ def on_output(