From 1112450c12f8c28753cc39f29226367f78b8ba4f Mon Sep 17 00:00:00 2001 From: Thomas Capelle Date: Tue, 6 Aug 2024 17:34:52 +0200 Subject: [PATCH] ruff and mypy --- weave/integrations/anthropic/anthropic_sdk.py | 9 ++++----- weave/trace/op_extensions/accumulator.py | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/weave/integrations/anthropic/anthropic_sdk.py b/weave/integrations/anthropic/anthropic_sdk.py index 915f3a0a5f0..9c284ff5bed 100644 --- a/weave/integrations/anthropic/anthropic_sdk.py +++ b/weave/integrations/anthropic/anthropic_sdk.py @@ -119,6 +119,7 @@ async def _async_wrapper( ## This code handles both cases by patching the _IteratorWrapper ## and adding a text_stream property to it. + def anthropic_stream_accumulator( acc: typing.Optional["Message"], value: "MessageStream", @@ -132,8 +133,6 @@ def anthropic_stream_accumulator( return acc - - class AnthropicIteratorWrapper(_IteratorWrapper): def __getattr__(self, name: str) -> Any: """Delegate all other attributes to the wrapped iterator.""" @@ -155,12 +154,12 @@ def __stream_text__(self) -> Union[Iterator[str], AsyncIterator[str]]: else: return self.__sync_stream_text__() - def __sync_stream_text__(self) -> Iterator[str]: + def __sync_stream_text__(self) -> Iterator[str]: # type: ignore[attr-defined] for chunk in self: - if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": + if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": # type: ignore[attr-defined] yield chunk.delta.text - async def __async_stream_text__(self) -> AsyncIterator[str]: + async def __async_stream_text__(self) -> AsyncIterator[str]: # type: ignore[attr-defined] async for chunk in self: if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": yield chunk.delta.text diff --git a/weave/trace/op_extensions/accumulator.py b/weave/trace/op_extensions/accumulator.py index 7f385449a28..fab4b7548a5 100644 --- a/weave/trace/op_extensions/accumulator.py +++ b/weave/trace/op_extensions/accumulator.py @@ -157,7 +157,7 @@ def add_accumulator( *, should_accumulate: Optional[Callable[[Dict], bool]] = None, on_finish_post_processor: Optional[Callable[[Any], Any]] = None, - iterator_wrapper: Type[_IteratorWrapper[Any]] = _IteratorWrapper, + iterator_wrapper: Type[_IteratorWrapper] = _IteratorWrapper, ) -> Op: """This is to be used internally only - specifically designed for integrations with streaming libraries. @@ -214,7 +214,7 @@ def _build_iterator_from_accumulator_for_op( value: Iterator[V], accumulator: Callable, on_finish: FinishCallbackType, - iterator_wrapper: "_IteratorWrapper" = _IteratorWrapper, + iterator_wrapper: Type["_IteratorWrapper"] = _IteratorWrapper, ) -> "_IteratorWrapper": acc: _Accumulator = _Accumulator(accumulator)