From adb740bec5304df1130b92b0ce4b90bb909f6e34 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 17 Dec 2024 20:20:45 -0500 Subject: [PATCH] feat(weave): Support op configuration for autopatched functions for remaining integrations (#3216) --- tests/integrations/litellm/litellm_test.py | 6 +- weave/integrations/anthropic/anthropic_sdk.py | 135 +++-- weave/integrations/cerebras/cerebras_sdk.py | 72 ++- weave/integrations/cohere/__init__.py | 2 +- weave/integrations/cohere/cohere_sdk.py | 198 +++++--- weave/integrations/dspy/dspy_sdk.py | 477 +++++++++++------- .../google_ai_studio/google_ai_studio_sdk.py | 129 +++-- weave/integrations/groq/groq_sdk.py | 77 ++- .../instructor/instructor_iterable_utils.py | 14 +- .../instructor/instructor_partial_utils.py | 7 +- .../integrations/instructor/instructor_sdk.py | 84 ++- weave/integrations/litellm/litellm.py | 73 ++- weave/integrations/mistral/__init__.py | 4 +- weave/integrations/mistral/v0/mistral.py | 124 +++-- weave/integrations/mistral/v1/mistral.py | 112 ++-- weave/integrations/notdiamond/__init__.py | 2 +- weave/integrations/notdiamond/tracing.py | 135 +++-- weave/integrations/vertexai/vertexai_sdk.py | 126 +++-- weave/trace/autopatch.py | 114 ++--- 19 files changed, 1206 insertions(+), 685 deletions(-) diff --git a/tests/integrations/litellm/litellm_test.py b/tests/integrations/litellm/litellm_test.py index 8cc0966c4768..ffd1094e6534 100644 --- a/tests/integrations/litellm/litellm_test.py +++ b/tests/integrations/litellm/litellm_test.py @@ -8,7 +8,7 @@ from packaging.version import parse as version_parse import weave -from weave.integrations.litellm.litellm import litellm_patcher +from weave.integrations.litellm.litellm import get_litellm_patcher # This PR: # https://github.com/BerriAI/litellm/commit/fe2aa706e8ff4edbcd109897e5da6b83ef6ad693 @@ -38,9 +38,9 @@ def patch_litellm(request: Any) -> Generator[None, None, None]: yield return - litellm_patcher.attempt_patch() + get_litellm_patcher().attempt_patch() yield - litellm_patcher.undo_patch() + get_litellm_patcher().undo_patch() @pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode diff --git a/weave/integrations/anthropic/anthropic_sdk.py b/weave/integrations/anthropic/anthropic_sdk.py index 6e33e1a39062..9cd06f532594 100644 --- a/weave/integrations/anthropic/anthropic_sdk.py +++ b/weave/integrations/anthropic/anthropic_sdk.py @@ -1,27 +1,26 @@ +from __future__ import annotations + import importlib from collections.abc import AsyncIterator, Iterator from functools import wraps -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Optional, - Union, -) +from typing import TYPE_CHECKING, Any, Callable import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import _IteratorWrapper, add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from anthropic.lib.streaming import MessageStream from anthropic.types import Message, MessageStreamEvent +_anthropic_patcher: MultiPatcher | None = None + def anthropic_accumulator( - acc: Optional["Message"], - value: "MessageStreamEvent", -) -> "Message": + acc: Message | None, + value: MessageStreamEvent, +) -> Message: from anthropic.types import ( ContentBlockDeltaEvent, Message, @@ -73,13 +72,11 @@ def should_use_accumulator(inputs: dict) -> bool: return isinstance(inputs, dict) and bool(inputs.get("stream")) -def create_wrapper_sync( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: "We need to do this so we can check if `stream` is used" - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: anthropic_accumulator, @@ -92,9 +89,7 @@ def wrapper(fn: Callable) -> Callable: # Surprisingly, the async `client.chat.completions.create` does not pass # `inspect.iscoroutinefunction`, so we can't dispatch on it and must write # it manually here... -def create_wrapper_async( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) @@ -104,8 +99,8 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper "We need to do this so we can check if `stream` is used" - op = weave.op()(_fn_wrapper(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_fn_wrapper(fn), **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: anthropic_accumulator, @@ -123,9 +118,9 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: def anthropic_stream_accumulator( - acc: Optional["Message"], - value: "MessageStream", -) -> "Message": + acc: Message | None, + value: MessageStream, +) -> Message: from anthropic.lib.streaming._types import MessageStopEvent if acc is None: @@ -150,7 +145,7 @@ def __getattr__(self, name: str) -> Any: return object.__getattribute__(self, name) return getattr(self._iterator_or_ctx_manager, name) - def __stream_text__(self) -> Union[Iterator[str], AsyncIterator[str]]: + def __stream_text__(self) -> Iterator[str] | AsyncIterator[str]: if isinstance(self._iterator_or_ctx_manager, AsyncIterator): return self.__async_stream_text__() else: @@ -167,16 +162,14 @@ async def __async_stream_text__(self) -> AsyncIterator[str]: # type: ignore yield chunk.delta.text # type: ignore @property - def text_stream(self) -> Union[Iterator[str], AsyncIterator[str]]: + def text_stream(self) -> Iterator[str] | AsyncIterator[str]: return self.__stream_text__() -def create_stream_wrapper( - name: str, -) -> Callable[[Callable], Callable]: +def create_stream_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda _: anthropic_stream_accumulator, @@ -187,28 +180,58 @@ def wrapper(fn: Callable) -> Callable: return wrapper -anthropic_patcher = MultiPatcher( - [ - # Patch the sync messages.create method for all messages.create methods - SymbolPatcher( - lambda: importlib.import_module("anthropic.resources.messages"), - "Messages.create", - create_wrapper_sync(name="anthropic.Messages.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("anthropic.resources.messages"), - "AsyncMessages.create", - create_wrapper_async(name="anthropic.AsyncMessages.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("anthropic.resources.messages"), - "Messages.stream", - create_stream_wrapper(name="anthropic.Messages.stream"), - ), - SymbolPatcher( - lambda: importlib.import_module("anthropic.resources.messages"), - "AsyncMessages.stream", - create_stream_wrapper(name="anthropic.AsyncMessages.stream"), - ), - ] -) +def get_anthropic_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _anthropic_patcher + if _anthropic_patcher is not None: + return _anthropic_patcher + + base = settings.op_settings + + messages_create_settings = base.model_copy( + update={"name": base.name or "anthropic.Messages.create"} + ) + async_messages_create_settings = base.model_copy( + update={"name": base.name or "anthropic.AsyncMessages.create"} + ) + stream_settings = base.model_copy( + update={"name": base.name or "anthropic.Messages.stream"} + ) + async_stream_settings = base.model_copy( + update={"name": base.name or "anthropic.AsyncMessages.stream"} + ) + + _anthropic_patcher = MultiPatcher( + [ + # Patch the sync messages.create method for all messages.create methods + SymbolPatcher( + lambda: importlib.import_module("anthropic.resources.messages"), + "Messages.create", + create_wrapper_sync(messages_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("anthropic.resources.messages"), + "AsyncMessages.create", + create_wrapper_async(async_messages_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("anthropic.resources.messages"), + "Messages.stream", + create_stream_wrapper(stream_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("anthropic.resources.messages"), + "AsyncMessages.stream", + create_stream_wrapper(async_stream_settings), + ), + ] + ) + + return _anthropic_patcher diff --git a/weave/integrations/cerebras/cerebras_sdk.py b/weave/integrations/cerebras/cerebras_sdk.py index bdce368290e0..a2096a184e79 100644 --- a/weave/integrations/cerebras/cerebras_sdk.py +++ b/weave/integrations/cerebras/cerebras_sdk.py @@ -1,25 +1,26 @@ +from __future__ import annotations + import importlib from functools import wraps from typing import Any, Callable import weave -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.autopatch import IntegrationSettings, OpSettings +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher + +_cerebras_patcher: MultiPatcher | None = None -def create_wrapper_sync( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return op return wrapper -def create_wrapper_async( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) @@ -28,24 +29,45 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper - op = weave.op()(_fn_wrapper(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_fn_wrapper(fn), **op_kwargs) return op return wrapper -cerebras_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("cerebras.cloud.sdk.resources.chat"), - "CompletionsResource.create", - create_wrapper_sync(name="cerebras.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("cerebras.cloud.sdk.resources.chat"), - "AsyncCompletionsResource.create", - create_wrapper_async(name="cerebras.chat.completions.create"), - ), - ] -) +def get_cerebras_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _cerebras_patcher + if _cerebras_patcher is not None: + return _cerebras_patcher + + base = settings.op_settings + + create_settings = base.model_copy( + update={"name": base.name or "cerebras.chat.completions.create"} + ) + + _cerebras_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("cerebras.cloud.sdk.resources.chat"), + "CompletionsResource.create", + create_wrapper_sync(create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("cerebras.cloud.sdk.resources.chat"), + "AsyncCompletionsResource.create", + create_wrapper_async(create_settings), + ), + ] + ) + + return _cerebras_patcher diff --git a/weave/integrations/cohere/__init__.py b/weave/integrations/cohere/__init__.py index 45f925b6eeac..288cce91aaea 100644 --- a/weave/integrations/cohere/__init__.py +++ b/weave/integrations/cohere/__init__.py @@ -1 +1 @@ -from weave.integrations.cohere.cohere_sdk import cohere_patcher as cohere_patcher +from weave.integrations.cohere.cohere_sdk import get_cohere_patcher # noqa: F401 diff --git a/weave/integrations/cohere/cohere_sdk.py b/weave/integrations/cohere/cohere_sdk.py index b0a5944795b4..a9b216c070b5 100644 --- a/weave/integrations/cohere/cohere_sdk.py +++ b/weave/integrations/cohere/cohere_sdk.py @@ -1,20 +1,23 @@ +from __future__ import annotations + import importlib from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from cohere.types.non_streamed_chat_response import NonStreamedChatResponse from cohere.v2.types.non_streamed_chat_response2 import NonStreamedChatResponse2 -def cohere_accumulator( - acc: Optional[dict], - value: Any, -) -> "NonStreamedChatResponse": +_cohere_patcher: MultiPatcher | None = None + + +def cohere_accumulator(acc: dict | None, value: Any) -> NonStreamedChatResponse: # don't need to accumulate, is build-in by cohere! # https://docs.cohere.com/docs/streaming # A stream-end event is the final event of the stream, and is returned only when streaming is finished. @@ -31,10 +34,7 @@ def cohere_accumulator( return acc -def cohere_accumulator_v2( - acc: Optional[dict], - value: Any, -) -> "NonStreamedChatResponse2": +def cohere_accumulator_v2(acc: dict | None, value: Any) -> NonStreamedChatResponse2: from cohere.v2.types.assistant_message_response import AssistantMessageResponse from cohere.v2.types.non_streamed_chat_response2 import NonStreamedChatResponse2 @@ -86,16 +86,16 @@ def _accumulate_content( return acc -def cohere_wrapper(name: str) -> Callable: +def cohere_wrapper(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: - op = weave.op(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return op return wrapper -def cohere_wrapper_v2(name: str) -> Callable: +def cohere_wrapper_v2(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: def _post_process_response(fn: Callable) -> Any: @wraps(fn) @@ -122,14 +122,14 @@ def _wrapper(*args: Any, **kwargs: Any) -> Any: return _wrapper - op = weave.op(_post_process_response(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_post_process_response(fn), **op_kwargs) return op return wrapper -def cohere_wrapper_async_v2(name: str) -> Callable: +def cohere_wrapper_async_v2(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: def _post_process_response(fn: Callable) -> Any: @wraps(fn) @@ -156,81 +156,119 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any: return _wrapper - op = weave.op(_post_process_response(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_post_process_response(fn), **op_kwargs) return op return wrapper -def cohere_stream_wrapper(name: str) -> Callable: +def cohere_stream_wrapper(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: - op = weave.op(fn) - op.name = name # type: ignore - return add_accumulator(op, lambda inputs: cohere_accumulator) # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) + return add_accumulator(op, lambda inputs: cohere_accumulator) return wrapper -def cohere_stream_wrapper_v2(name: str) -> Callable: +def cohere_stream_wrapper_v2(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: - op = weave.op(fn) - op.name = name # type: ignore - return add_accumulator( - op, make_accumulator=lambda inputs: cohere_accumulator_v2 - ) + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) + return add_accumulator(op, lambda inputs: cohere_accumulator_v2) return wrapper -cohere_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "Client.chat", - cohere_wrapper("cohere.Client.chat"), - ), - # Patch the async chat method - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "AsyncClient.chat", - cohere_wrapper("cohere.AsyncClient.chat"), - ), - # Add patch for chat_stream method - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "Client.chat_stream", - cohere_stream_wrapper("cohere.Client.chat_stream"), - ), - # Add patch for async chat_stream method - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "AsyncClient.chat_stream", - cohere_stream_wrapper("cohere.AsyncClient.chat_stream"), - ), - # Add patch for cohere v2 - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "ClientV2.chat", - cohere_wrapper_v2("cohere.ClientV2.chat"), - ), - # Add patch for cohre v2 async chat method - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "AsyncClientV2.chat", - cohere_wrapper_async_v2("cohere.AsyncClientV2.chat"), - ), - # Add patch for chat_stream method v2 - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "ClientV2.chat_stream", - cohere_stream_wrapper_v2("cohere.ClientV2.chat_stream"), - ), - # Add patch for async chat_stream method v2 - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "AsyncClientV2.chat_stream", - cohere_stream_wrapper_v2("cohere.AsyncClientV2.chat_stream"), - ), - ] -) +def get_cohere_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _cohere_patcher + if _cohere_patcher is not None: + return _cohere_patcher + + base = settings.op_settings + + chat_settings = base.model_copy(update={"name": base.name or "cohere.Client.chat"}) + async_chat_settings = base.model_copy( + update={"name": base.name or "cohere.AsyncClient.chat"} + ) + chat_stream_settings = base.model_copy( + update={"name": base.name or "cohere.Client.chat_stream"} + ) + async_chat_stream_settings = base.model_copy( + update={"name": base.name or "cohere.AsyncClient.chat_stream"} + ) + chat_v2_settings = base.model_copy( + update={"name": base.name or "cohere.ClientV2.chat"} + ) + async_chat_v2_settings = base.model_copy( + update={"name": base.name or "cohere.AsyncClientV2.chat"} + ) + chat_stream_v2_settings = base.model_copy( + update={"name": base.name or "cohere.ClientV2.chat_stream"} + ) + async_chat_stream_v2_settings = base.model_copy( + update={"name": base.name or "cohere.AsyncClientV2.chat_stream"} + ) + + _cohere_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "Client.chat", + cohere_wrapper(chat_settings), + ), + # Patch the async chat method + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "AsyncClient.chat", + cohere_wrapper(async_chat_settings), + ), + # Add patch for chat_stream method + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "Client.chat_stream", + cohere_stream_wrapper(chat_stream_settings), + ), + # Add patch for async chat_stream method + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "AsyncClient.chat_stream", + cohere_stream_wrapper(async_chat_stream_settings), + ), + # Add patch for cohere v2 + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "ClientV2.chat", + cohere_wrapper_v2(chat_v2_settings), + ), + # Add patch for cohre v2 async chat method + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "AsyncClientV2.chat", + cohere_wrapper_async_v2(async_chat_v2_settings), + ), + # Add patch for chat_stream method v2 + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "ClientV2.chat_stream", + cohere_stream_wrapper_v2(chat_stream_v2_settings), + ), + # Add patch for async chat_stream method v2 + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "AsyncClientV2.chat_stream", + cohere_stream_wrapper_v2(async_chat_stream_v2_settings), + ), + ] + ) + + return _cohere_patcher diff --git a/weave/integrations/dspy/dspy_sdk.py b/weave/integrations/dspy/dspy_sdk.py index b205c77a4588..25293b0f4947 100644 --- a/weave/integrations/dspy/dspy_sdk.py +++ b/weave/integrations/dspy/dspy_sdk.py @@ -1,216 +1,331 @@ +from __future__ import annotations + import importlib from typing import Callable import weave -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.autopatch import IntegrationSettings, OpSettings +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher + +_dspy_patcher: MultiPatcher | None = None -def dspy_wrapper(name: str) -> Callable[[Callable], Callable]: +def dspy_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return op return wrapper def dspy_get_patched_lm_functions( - base_symbol: str, lm_class_name: str + base_symbol: str, lm_class_name: str, settings: OpSettings ) -> list[SymbolPatcher]: patchable_functional_attributes = [ "basic_request", "request", "__call__", ] + basic_request_settings = settings.model_copy( + update={"name": settings.name or f"{base_symbol}.{lm_class_name}.basic_request"} + ) + request_settings = settings.model_copy( + update={"name": settings.name or f"{base_symbol}.{lm_class_name}.request"} + ) + call_settings = settings.model_copy( + update={"name": settings.name or f"{base_symbol}.{lm_class_name}"} + ) return [ SymbolPatcher( get_base_symbol=lambda: importlib.import_module(base_symbol), attribute_name=f"{lm_class_name}.basic_request", - make_new_value=dspy_wrapper(f"dspy.{lm_class_name}.basic_request"), + make_new_value=dspy_wrapper(basic_request_settings), ), SymbolPatcher( get_base_symbol=lambda: importlib.import_module(base_symbol), attribute_name=f"{lm_class_name}.request", - make_new_value=dspy_wrapper(f"dspy.{lm_class_name}.request"), + make_new_value=dspy_wrapper(request_settings), ), SymbolPatcher( get_base_symbol=lambda: importlib.import_module(base_symbol), attribute_name=f"{lm_class_name}.__call__", - make_new_value=dspy_wrapper(f"dspy.{lm_class_name}"), + make_new_value=dspy_wrapper(call_settings), + ), + ] + + +def get_dspy_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _dspy_patcher + if _dspy_patcher is not None: + return _dspy_patcher + + base = settings.op_settings + + predict_call_settings = base.model_copy( + update={"name": base.name or "dspy.Predict"} + ) + predict_forward_settings = base.model_copy( + update={"name": base.name or "dspy.Predict.forward"} + ) + typed_predictor_call_settings = base.model_copy( + update={"name": base.name or "dspy.TypedPredictor"} + ) + typed_predictor_forward_settings = base.model_copy( + update={"name": base.name or "dspy.TypedPredictor.forward"} + ) + module_call_settings = base.model_copy(update={"name": base.name or "dspy.Module"}) + typed_chain_of_thought_call_settings = base.model_copy( + update={"name": base.name or "dspy.TypedChainOfThought"} + ) + retrieve_call_settings = base.model_copy( + update={"name": base.name or "dspy.Retrieve"} + ) + retrieve_forward_settings = base.model_copy( + update={"name": base.name or "dspy.Retrieve.forward"} + ) + evaluate_call_settings = base.model_copy( + update={"name": base.name or "dspy.evaluate.Evaluate"} + ) + bootstrap_few_shot_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.BootstrapFewShot.compile"} + ) + copro_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.COPRO.compile"} + ) + ensemble_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.Ensemble.compile"} + ) + bootstrap_finetune_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.BootstrapFinetune.compile"} + ) + knn_few_shot_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.KNNFewShot.compile"} + ) + mipro_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.MIPRO.compile"} + ) + bootstrap_few_shot_with_random_search_compile_settings = base.model_copy( + update={ + "name": base.name + or "dspy.teleprompt.BootstrapFewShotWithRandomSearch.compile" + } + ) + signature_optimizer_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.SignatureOptimizer.compile"} + ) + bayesian_signature_optimizer_compile_settings = base.model_copy( + update={ + "name": base.name or "dspy.teleprompt.BayesianSignatureOptimizer.compile" + } + ) + signature_opt_typed_optimize_signature_settings = base.model_copy( + update={ + "name": base.name + or "dspy.teleprompt.signature_opt_typed.optimize_signature" + } + ) + bootstrap_few_shot_with_optuna_compile_settings = base.model_copy( + update={ + "name": base.name or "dspy.teleprompt.BootstrapFewShotWithOptuna.compile" + } + ) + labeled_few_shot_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.LabeledFewShot.compile"} + ) + + patched_functions = [ + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Predict.__call__", + make_new_value=dspy_wrapper(predict_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Predict.forward", + make_new_value=dspy_wrapper(predict_forward_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="TypedPredictor.__call__", + make_new_value=dspy_wrapper(typed_predictor_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="TypedPredictor.forward", + make_new_value=dspy_wrapper(typed_predictor_forward_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Module.__call__", + make_new_value=dspy_wrapper(module_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="TypedChainOfThought.__call__", + make_new_value=dspy_wrapper(typed_chain_of_thought_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Retrieve.__call__", + make_new_value=dspy_wrapper(retrieve_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Retrieve.forward", + make_new_value=dspy_wrapper(retrieve_forward_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.evaluate.evaluate"), + attribute_name="Evaluate.__call__", + make_new_value=dspy_wrapper(evaluate_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="BootstrapFewShot.compile", + make_new_value=dspy_wrapper(bootstrap_few_shot_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="COPRO.compile", + make_new_value=dspy_wrapper(copro_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="Ensemble.compile", + make_new_value=dspy_wrapper(ensemble_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="BootstrapFinetune.compile", + make_new_value=dspy_wrapper(bootstrap_finetune_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="KNNFewShot.compile", + make_new_value=dspy_wrapper(knn_few_shot_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="MIPRO.compile", + make_new_value=dspy_wrapper(mipro_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="BootstrapFewShotWithRandomSearch.compile", + make_new_value=dspy_wrapper( + bootstrap_few_shot_with_random_search_compile_settings + ), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="SignatureOptimizer.compile", + make_new_value=dspy_wrapper(signature_optimizer_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="BayesianSignatureOptimizer.compile", + make_new_value=dspy_wrapper(bayesian_signature_optimizer_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module( + "dspy.teleprompt.signature_opt_typed" + ), + attribute_name="optimize_signature", + make_new_value=dspy_wrapper( + signature_opt_typed_optimize_signature_settings + ), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="BootstrapFewShotWithOptuna.compile", + make_new_value=dspy_wrapper( + bootstrap_few_shot_with_optuna_compile_settings + ), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="LabeledFewShot.compile", + make_new_value=dspy_wrapper(labeled_few_shot_compile_settings), + ), + ] + + # Patch LM classes + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="AzureOpenAI", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="OpenAI", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="Cohere", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="Clarifai", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="Google", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="HFClientTGI", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="HFClientVLLM", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="Anyscale", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="Together", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="OllamaLocal", settings=base + ) + + databricks_basic_request_settings = base.model_copy( + update={"name": base.name or "dspy.Databricks.basic_request"} + ) + databricks_call_settings = base.model_copy( + update={"name": base.name or "dspy.Databricks"} + ) + colbertv2_call_settings = base.model_copy( + update={"name": base.name or "dspy.ColBERTv2"} + ) + pyserini_call_settings = base.model_copy( + update={"name": base.name or "dspy.Pyserini"} + ) + + patched_functions += [ + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Databricks.basic_request", + make_new_value=dspy_wrapper(databricks_basic_request_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Databricks.__call__", + make_new_value=dspy_wrapper(databricks_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="ColBERTv2.__call__", + make_new_value=dspy_wrapper(colbertv2_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Pyserini.__call__", + make_new_value=dspy_wrapper(pyserini_call_settings), ), ] + _dspy_patcher = MultiPatcher(patched_functions) -patched_functions = [ - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Predict.__call__", - make_new_value=dspy_wrapper("dspy.Predict"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Predict.forward", - make_new_value=dspy_wrapper("dspy.Predict.forward"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="TypedPredictor.__call__", - make_new_value=dspy_wrapper("dspy.TypedPredictor"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="TypedPredictor.forward", - make_new_value=dspy_wrapper("dspy.TypedPredictor.forward"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Module.__call__", - make_new_value=dspy_wrapper("dspy.Module"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="TypedChainOfThought.__call__", - make_new_value=dspy_wrapper("dspy.TypedChainOfThought"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Retrieve.__call__", - make_new_value=dspy_wrapper("dspy.Retrieve"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Retrieve.forward", - make_new_value=dspy_wrapper("dspy.Retrieve.forward"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.evaluate.evaluate"), - attribute_name="Evaluate.__call__", - make_new_value=dspy_wrapper("dspy.evaluate.Evaluate"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="BootstrapFewShot.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.BootstrapFewShot.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="COPRO.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.COPRO.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="Ensemble.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.Ensemble.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="BootstrapFinetune.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.BootstrapFinetune.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="KNNFewShot.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.KNNFewShot.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="MIPRO.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.MIPRO.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="BootstrapFewShotWithRandomSearch.compile", - make_new_value=dspy_wrapper( - "dspy.teleprompt.BootstrapFewShotWithRandomSearch.compile" - ), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="SignatureOptimizer.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.SignatureOptimizer.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="BayesianSignatureOptimizer.compile", - make_new_value=dspy_wrapper( - "dspy.teleprompt.BayesianSignatureOptimizer.compile" - ), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module( - "dspy.teleprompt.signature_opt_typed" - ), - attribute_name="optimize_signature", - make_new_value=dspy_wrapper( - "dspy.teleprompt.signature_opt_typed.optimize_signature" - ), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="BootstrapFewShotWithOptuna.compile", - make_new_value=dspy_wrapper( - "dspy.teleprompt.BootstrapFewShotWithOptuna.compile" - ), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="LabeledFewShot.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.LabeledFewShot.compile"), - ), -] - -# Patch LM classes -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="AzureOpenAI" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="OpenAI" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="Cohere" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="Clarifai" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="Google" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="HFClientTGI" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="HFClientVLLM" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="Anyscale" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="Together" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="OllamaLocal" -) -patched_functions += [ - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Databricks.basic_request", - make_new_value=dspy_wrapper("dspy.Databricks.basic_request"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Databricks.__call__", - make_new_value=dspy_wrapper("dspy.Databricks"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="ColBERTv2.__call__", - make_new_value=dspy_wrapper("dspy.ColBERTv2"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Pyserini.__call__", - make_new_value=dspy_wrapper("dspy.Pyserini"), - ), -] - -dspy_patcher = MultiPatcher(patched_functions) + return _dspy_patcher diff --git a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py index 2cd8a2fe1377..7c4e6d2a7406 100644 --- a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py +++ b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py @@ -1,16 +1,21 @@ +from __future__ import annotations + import importlib from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher from weave.trace.serialize import dictify from weave.trace.weave_client import Call if TYPE_CHECKING: from google.generativeai.types.generation_types import GenerateContentResponse +_google_genai_patcher: MultiPatcher | None = None + def gemini_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: if "self" in inputs: @@ -19,8 +24,8 @@ def gemini_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: def gemini_accumulator( - acc: Optional["GenerateContentResponse"], value: "GenerateContentResponse" -) -> "GenerateContentResponse": + acc: GenerateContentResponse | None, value: GenerateContentResponse +) -> GenerateContentResponse: if acc is None: return value @@ -64,9 +69,7 @@ def gemini_accumulator( return acc -def gemini_on_finish( - call: Call, output: Any, exception: Optional[BaseException] -) -> None: +def gemini_on_finish(call: Call, output: Any, exception: BaseException | None) -> None: if "model_name" in call.inputs["self"]: original_model_name = call.inputs["self"]["model_name"] elif "model" in call.inputs["self"]: @@ -89,10 +92,13 @@ def gemini_on_finish( call.summary.update(summary_update) -def gemini_wrapper_sync(name: str) -> Callable[[Callable], Callable]: +def gemini_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op(postprocess_inputs=gemini_postprocess_inputs)(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + if not op_kwargs.get("postprocess_inputs"): + op_kwargs["postprocess_inputs"] = gemini_postprocess_inputs + + op = weave.op(fn, **op_kwargs) op._set_on_finish_handler(gemini_on_finish) return add_accumulator( op, # type: ignore @@ -104,7 +110,7 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def gemini_wrapper_async(name: str) -> Callable[[Callable], Callable]: +def gemini_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) @@ -113,9 +119,11 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper - "We need to do this so we can check if `stream` is used" - op = weave.op(postprocess_inputs=gemini_postprocess_inputs)(_fn_wrapper(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + if not op_kwargs.get("postprocess_inputs"): + op_kwargs["postprocess_inputs"] = gemini_postprocess_inputs + + op = weave.op(_fn_wrapper(fn), **op_kwargs) op._set_on_finish_handler(gemini_on_finish) return add_accumulator( op, # type: ignore @@ -127,33 +135,72 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -google_genai_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("google.generativeai.generative_models"), - "GenerativeModel.generate_content", - gemini_wrapper_sync( - name="google.generativeai.GenerativeModel.generate_content" +def get_google_genai_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _google_genai_patcher + if _google_genai_patcher is not None: + return _google_genai_patcher + + base = settings.op_settings + + generate_content_settings = base.model_copy( + update={ + "name": base.name or "google.generativeai.GenerativeModel.generate_content" + } + ) + generate_content_async_settings = base.model_copy( + update={ + "name": base.name + or "google.generativeai.GenerativeModel.generate_content_async" + } + ) + send_message_settings = base.model_copy( + update={"name": base.name or "google.generativeai.ChatSession.send_message"} + ) + send_message_async_settings = base.model_copy( + update={ + "name": base.name or "google.generativeai.ChatSession.send_message_async" + } + ) + + _google_genai_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module( + "google.generativeai.generative_models" + ), + "GenerativeModel.generate_content", + gemini_wrapper_sync(generate_content_settings), + ), + SymbolPatcher( + lambda: importlib.import_module( + "google.generativeai.generative_models" + ), + "GenerativeModel.generate_content_async", + gemini_wrapper_async(generate_content_async_settings), ), - ), - SymbolPatcher( - lambda: importlib.import_module("google.generativeai.generative_models"), - "GenerativeModel.generate_content_async", - gemini_wrapper_async( - name="google.generativeai.GenerativeModel.generate_content_async" + SymbolPatcher( + lambda: importlib.import_module( + "google.generativeai.generative_models" + ), + "ChatSession.send_message", + gemini_wrapper_sync(send_message_settings), ), - ), - SymbolPatcher( - lambda: importlib.import_module("google.generativeai.generative_models"), - "ChatSession.send_message", - gemini_wrapper_sync(name="google.generativeai.ChatSession.send_message"), - ), - SymbolPatcher( - lambda: importlib.import_module("google.generativeai.generative_models"), - "ChatSession.send_message_async", - gemini_wrapper_async( - name="google.generativeai.ChatSession.send_message_async" + SymbolPatcher( + lambda: importlib.import_module( + "google.generativeai.generative_models" + ), + "ChatSession.send_message_async", + gemini_wrapper_async(send_message_async_settings), ), - ), - ] -) + ] + ) + + return _google_genai_patcher diff --git a/weave/integrations/groq/groq_sdk.py b/weave/integrations/groq/groq_sdk.py index 4f470e6d7437..c5c07fd705f7 100644 --- a/weave/integrations/groq/groq_sdk.py +++ b/weave/integrations/groq/groq_sdk.py @@ -1,17 +1,23 @@ +from __future__ import annotations + import importlib -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable + +import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings +from weave.trace.op_extensions.accumulator import add_accumulator +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from groq.types.chat import ChatCompletion, ChatCompletionChunk -import weave -from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher + +_groq_patcher: MultiPatcher | None = None def groq_accumulator( - acc: Optional["ChatCompletion"], value: "ChatCompletionChunk" -) -> "ChatCompletion": + acc: ChatCompletion | None, value: ChatCompletionChunk +) -> ChatCompletion: from groq.types.chat import ChatCompletion, ChatCompletionMessage from groq.types.chat.chat_completion import Choice from groq.types.chat.chat_completion_chunk import Choice as ChoiceChunk @@ -83,11 +89,10 @@ def should_use_accumulator(inputs: dict) -> bool: return isinstance(inputs, dict) and bool(inputs.get("stream")) -def groq_wrapper(name: str) -> Callable[[Callable], Callable]: +def groq_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore - # return op + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: groq_accumulator, @@ -97,17 +102,41 @@ def wrapper(fn: Callable) -> Callable: return wrapper -groq_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("groq.resources.chat.completions"), - "Completions.create", - groq_wrapper(name="groq.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("groq.resources.chat.completions"), - "AsyncCompletions.create", - groq_wrapper(name="groq.async.chat.completions.create"), - ), - ] -) +def get_groq_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _groq_patcher + if _groq_patcher is not None: + return _groq_patcher + + base = settings.op_settings + + chat_completions_settings = base.model_copy( + update={"name": base.name or "groq.chat.completions.create"} + ) + async_chat_completions_settings = base.model_copy( + update={"name": base.name or "groq.async.chat.completions.create"} + ) + + _groq_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("groq.resources.chat.completions"), + "Completions.create", + groq_wrapper(chat_completions_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("groq.resources.chat.completions"), + "AsyncCompletions.create", + groq_wrapper(async_chat_completions_settings), + ), + ] + ) + + return _groq_patcher diff --git a/weave/integrations/instructor/instructor_iterable_utils.py b/weave/integrations/instructor/instructor_iterable_utils.py index 84d64a103b6b..3b0f128a1320 100644 --- a/weave/integrations/instructor/instructor_iterable_utils.py +++ b/weave/integrations/instructor/instructor_iterable_utils.py @@ -4,6 +4,7 @@ from pydantic import BaseModel import weave +from weave.trace.autopatch import OpSettings from weave.trace.op_extensions.accumulator import add_accumulator @@ -27,10 +28,10 @@ def should_accumulate_iterable(inputs: dict) -> bool: return False -def instructor_wrapper_sync(name: str) -> Callable[[Callable], Callable]: +def instructor_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: instructor_iterable_accumulator, @@ -40,7 +41,7 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def instructor_wrapper_async(name: str) -> Callable[[Callable], Callable]: +def instructor_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) @@ -49,9 +50,8 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper - "We need to do this so we can check if `stream` is used" - op = weave.op(_fn_wrapper(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_fn_wrapper(fn), **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: instructor_iterable_accumulator, diff --git a/weave/integrations/instructor/instructor_partial_utils.py b/weave/integrations/instructor/instructor_partial_utils.py index f90dc7edb17b..8efa84b302f8 100644 --- a/weave/integrations/instructor/instructor_partial_utils.py +++ b/weave/integrations/instructor/instructor_partial_utils.py @@ -3,6 +3,7 @@ from pydantic import BaseModel import weave +from weave.trace.autopatch import OpSettings from weave.trace.op_extensions.accumulator import add_accumulator @@ -14,10 +15,10 @@ def instructor_partial_accumulator( return acc -def instructor_wrapper_partial(name: str) -> Callable[[Callable], Callable]: +def instructor_wrapper_partial(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: instructor_partial_accumulator, diff --git a/weave/integrations/instructor/instructor_sdk.py b/weave/integrations/instructor/instructor_sdk.py index 867e9f2a785e..00dfde029c9c 100644 --- a/weave/integrations/instructor/instructor_sdk.py +++ b/weave/integrations/instructor/instructor_sdk.py @@ -1,31 +1,65 @@ +from __future__ import annotations + import importlib -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.autopatch import IntegrationSettings +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher from .instructor_iterable_utils import instructor_wrapper_async, instructor_wrapper_sync from .instructor_partial_utils import instructor_wrapper_partial -instructor_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("instructor.client"), - "Instructor.create", - instructor_wrapper_sync(name="Instructor.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("instructor.client"), - "AsyncInstructor.create", - instructor_wrapper_async(name="AsyncInstructor.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("instructor.client"), - "Instructor.create_partial", - instructor_wrapper_partial(name="Instructor.create_partial"), - ), - SymbolPatcher( - lambda: importlib.import_module("instructor.client"), - "AsyncInstructor.create_partial", - instructor_wrapper_partial(name="AsyncInstructor.create_partial"), - ), - ] -) +_instructor_patcher: MultiPatcher | None = None + + +def get_instructor_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _instructor_patcher + if _instructor_patcher is not None: + return _instructor_patcher + + base = settings.op_settings + + create_settings = base.model_copy(update={"name": base.name or "Instructor.create"}) + async_create_settings = base.model_copy( + update={"name": base.name or "AsyncInstructor.create"} + ) + create_partial_settings = base.model_copy( + update={"name": base.name or "Instructor.create_partial"} + ) + async_create_partial_settings = base.model_copy( + update={"name": base.name or "AsyncInstructor.create_partial"} + ) + + _instructor_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("instructor.client"), + "Instructor.create", + instructor_wrapper_sync(create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("instructor.client"), + "AsyncInstructor.create", + instructor_wrapper_async(async_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("instructor.client"), + "Instructor.create_partial", + instructor_wrapper_partial(create_partial_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("instructor.client"), + "AsyncInstructor.create_partial", + instructor_wrapper_partial(async_create_partial_settings), + ), + ] + ) + + return _instructor_patcher diff --git a/weave/integrations/litellm/litellm.py b/weave/integrations/litellm/litellm.py index c3bf1bf114a2..9ae6e492c84e 100644 --- a/weave/integrations/litellm/litellm.py +++ b/weave/integrations/litellm/litellm.py @@ -1,19 +1,24 @@ +from __future__ import annotations + import importlib -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from litellm.utils import ModelResponse +_litellm_patcher: MultiPatcher | None = None + # This accumulator is nearly identical to the mistral accumulator, just with different types. def litellm_accumulator( - acc: Optional["ModelResponse"], - value: "ModelResponse", -) -> "ModelResponse": + acc: ModelResponse | None, + value: ModelResponse, +) -> ModelResponse: # This import should be safe at this point from litellm.utils import Choices, Message, ModelResponse, Usage @@ -82,10 +87,10 @@ def should_use_accumulator(inputs: dict) -> bool: return isinstance(inputs, dict) and bool(inputs.get("stream")) -def make_wrapper(name: str) -> Callable: +def make_wrapper(settings: OpSettings) -> Callable: def litellm_wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: litellm_accumulator, @@ -96,17 +101,41 @@ def litellm_wrapper(fn: Callable) -> Callable: return litellm_wrapper -litellm_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("litellm"), - "completion", - make_wrapper("litellm.completion"), - ), - SymbolPatcher( - lambda: importlib.import_module("litellm"), - "acompletion", - make_wrapper("litellm.acompletion"), - ), - ] -) +def get_litellm_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _litellm_patcher + if _litellm_patcher is not None: + return _litellm_patcher + + base = settings.op_settings + + completion_settings = base.model_copy( + update={"name": base.name or "litellm.completion"} + ) + acompletion_settings = base.model_copy( + update={"name": base.name or "litellm.acompletion"} + ) + + _litellm_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("litellm"), + "completion", + make_wrapper(completion_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("litellm"), + "acompletion", + make_wrapper(acompletion_settings), + ), + ] + ) + + return _litellm_patcher diff --git a/weave/integrations/mistral/__init__.py b/weave/integrations/mistral/__init__.py index 34b40835efcf..d78812d9afa8 100644 --- a/weave/integrations/mistral/__init__.py +++ b/weave/integrations/mistral/__init__.py @@ -8,10 +8,10 @@ mistral_version = "1.0" # we need to return a patching function if version.parse(mistral_version) < version.parse("1.0.0"): - from .v0.mistral import mistral_patcher + from .v0.mistral import get_mistral_patcher # noqa: F401 print( f"Using MistralAI version {mistral_version}. Please consider upgrading to version 1.0.0 or later." ) else: - from .v1.mistral import mistral_patcher # noqa: F401 + from .v1.mistral import get_mistral_patcher # noqa: F401 diff --git a/weave/integrations/mistral/v0/mistral.py b/weave/integrations/mistral/v0/mistral.py index 6d4915eab41b..70a3fa183bb8 100644 --- a/weave/integrations/mistral/v0/mistral.py +++ b/weave/integrations/mistral/v0/mistral.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import importlib -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from mistralai.models.chat_completion import ( @@ -11,11 +14,13 @@ ChatCompletionStreamResponse, ) +_mistral_patcher: MultiPatcher | None = None + def mistral_accumulator( - acc: Optional["ChatCompletionResponse"], - value: "ChatCompletionStreamResponse", -) -> "ChatCompletionResponse": + acc: ChatCompletionResponse | None, + value: ChatCompletionStreamResponse, +) -> ChatCompletionResponse: # This import should be safe at this point from mistralai.models.chat_completion import ( ChatCompletionResponse, @@ -72,37 +77,78 @@ def mistral_accumulator( return acc -def mistral_stream_wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - acc_op = add_accumulator(op, lambda inputs: mistral_accumulator) # type: ignore - return acc_op - - -mistral_patcher = MultiPatcher( - [ - # Patch the sync, non-streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.client"), - "MistralClient.chat", - weave.op(), - ), - # Patch the sync, streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.client"), - "MistralClient.chat_stream", - mistral_stream_wrapper, - ), - # Patch the async, non-streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.async_client"), - "MistralAsyncClient.chat", - weave.op(), - ), - # Patch the async, streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.async_client"), - "MistralAsyncClient.chat_stream", - mistral_stream_wrapper, - ), - ] -) +def mistral_stream_wrapper(settings: OpSettings) -> Callable: + def wrapper(fn: Callable) -> Callable: + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) + acc_op = add_accumulator(op, lambda inputs: mistral_accumulator) # type: ignore + return acc_op + + return wrapper + + +def mistral_wrapper(settings: OpSettings) -> Callable: + def wrapper(fn: Callable) -> Callable: + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) + return op + + return wrapper + + +def get_mistral_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _mistral_patcher + if _mistral_patcher is not None: + return _mistral_patcher + + base = settings.op_settings + + chat_settings = base.model_copy(update={"name": base.name or "mistralai.chat"}) + chat_stream_settings = base.model_copy( + update={"name": base.name or "mistralai.chat_stream"} + ) + async_chat_settings = base.model_copy( + update={"name": base.name or "mistralai.async_client.chat"} + ) + async_chat_stream_settings = base.model_copy( + update={"name": base.name or "mistralai.async_client.chat_stream"} + ) + + _mistral_patcher = MultiPatcher( + [ + # Patch the sync, non-streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.client"), + "MistralClient.chat", + mistral_wrapper(chat_settings), + ), + # Patch the sync, streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.client"), + "MistralClient.chat_stream", + mistral_stream_wrapper(chat_stream_settings), + ), + # Patch the async, non-streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.async_client"), + "MistralAsyncClient.chat", + mistral_wrapper(async_chat_settings), + ), + # Patch the async, streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.async_client"), + "MistralAsyncClient.chat_stream", + mistral_stream_wrapper(async_chat_stream_settings), + ), + ] + ) + + return _mistral_patcher diff --git a/weave/integrations/mistral/v1/mistral.py b/weave/integrations/mistral/v1/mistral.py index 692aa7b159bf..d52d42af3c4f 100644 --- a/weave/integrations/mistral/v1/mistral.py +++ b/weave/integrations/mistral/v1/mistral.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import importlib -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from mistralai.models import ( @@ -11,11 +14,13 @@ CompletionEvent, ) +_mistral_patcher: MultiPatcher | None = None + def mistral_accumulator( - acc: Optional["ChatCompletionResponse"], - value: "CompletionEvent", -) -> "ChatCompletionResponse": + acc: ChatCompletionResponse | None, + value: CompletionEvent, +) -> ChatCompletionResponse: # This import should be safe at this point from mistralai.models import ( AssistantMessage, @@ -79,50 +84,79 @@ def mistral_accumulator( return acc -def mistral_stream_wrapper(name: str) -> Callable: +def mistral_stream_wrapper(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) acc_op = add_accumulator(op, lambda inputs: mistral_accumulator) # type: ignore - acc_op.name = name # type: ignore return acc_op return wrapper -def mistral_wrapper(name: str) -> Callable: +def mistral_wrapper(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return op return wrapper -mistral_patcher = MultiPatcher( - [ - # Patch the sync, non-streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.chat"), - "Chat.complete", - mistral_wrapper(name="Mistral.chat.complete"), - ), - # Patch the sync, streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.chat"), - "Chat.stream", - mistral_stream_wrapper(name="Mistral.chat.stream"), - ), - # Patch the async, non-streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.chat"), - "Chat.complete_async", - mistral_wrapper(name="Mistral.chat.complete_async"), - ), - # Patch the async, streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.chat"), - "Chat.stream_async", - mistral_stream_wrapper(name="Mistral.chat.stream_async"), - ), - ] -) +def get_mistral_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _mistral_patcher + if _mistral_patcher is not None: + return _mistral_patcher + + base = settings.op_settings + chat_complete_settings = base.model_copy( + update={"name": base.name or "mistralai.chat.complete"} + ) + chat_stream_settings = base.model_copy( + update={"name": base.name or "mistralai.chat.stream"} + ) + async_chat_complete_settings = base.model_copy( + update={"name": base.name or "mistralai.async_client.chat.complete"} + ) + async_chat_stream_settings = base.model_copy( + update={"name": base.name or "mistralai.async_client.chat.stream"} + ) + + _mistral_patcher = MultiPatcher( + [ + # Patch the sync, non-streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.chat"), + "Chat.complete", + mistral_wrapper(chat_complete_settings), + ), + # Patch the sync, streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.chat"), + "Chat.stream", + mistral_stream_wrapper(chat_stream_settings), + ), + # Patch the async, non-streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.chat"), + "Chat.complete_async", + mistral_wrapper(async_chat_complete_settings), + ), + # Patch the async, streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.chat"), + "Chat.stream_async", + mistral_stream_wrapper(async_chat_stream_settings), + ), + ] + ) + + return _mistral_patcher diff --git a/weave/integrations/notdiamond/__init__.py b/weave/integrations/notdiamond/__init__.py index d99c31c4176d..8cb72ef2a55e 100644 --- a/weave/integrations/notdiamond/__init__.py +++ b/weave/integrations/notdiamond/__init__.py @@ -1 +1 @@ -from .tracing import notdiamond_patcher as notdiamond_patcher +from .tracing import get_notdiamond_patcher # noqa: F401 diff --git a/weave/integrations/notdiamond/tracing.py b/weave/integrations/notdiamond/tracing.py index 90c08b9e8c62..23589719be78 100644 --- a/weave/integrations/notdiamond/tracing.py +++ b/weave/integrations/notdiamond/tracing.py @@ -1,19 +1,32 @@ +from __future__ import annotations + import importlib from typing import Callable import weave -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.autopatch import IntegrationSettings, OpSettings +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher + +_notdiamond_patcher: MultiPatcher | None = None -def nd_wrapper(name: str) -> Callable[[Callable], Callable]: +def nd_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return op return wrapper +def passthrough_wrapper(settings: OpSettings) -> Callable: + def wrapper(fn: Callable) -> Callable: + op_kwargs = settings.model_dump() + return weave.op(fn, **op_kwargs) + + return wrapper + + def _patch_client_op(method_name: str) -> list[SymbolPatcher]: return [ SymbolPatcher( @@ -29,36 +42,84 @@ def _patch_client_op(method_name: str) -> list[SymbolPatcher]: ] -patched_client_functions = _patch_client_op("model_select") - -patched_llmconfig_functions = [ - SymbolPatcher( - lambda: importlib.import_module("notdiamond"), - "LLMConfig.__init__", - weave.op(), - ), - SymbolPatcher( - lambda: importlib.import_module("notdiamond"), - "LLMConfig.from_string", - weave.op(), - ), -] - -patched_toolkit_functions = [ - SymbolPatcher( - lambda: importlib.import_module("notdiamond.toolkit.custom_router"), - "CustomRouter.fit", - weave.op(), - ), - SymbolPatcher( - lambda: importlib.import_module("notdiamond.toolkit.custom_router"), - "CustomRouter.eval", - weave.op(), - ), -] - -all_patched_functions = ( - patched_client_functions + patched_toolkit_functions + patched_llmconfig_functions -) - -notdiamond_patcher = MultiPatcher(all_patched_functions) +def get_notdiamond_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _notdiamond_patcher + if _notdiamond_patcher is not None: + return _notdiamond_patcher + + base = settings.op_settings + + model_select_settings = base.model_copy( + update={"name": base.name or "NotDiamond.model_select"} + ) + async_model_select_settings = base.model_copy( + update={"name": base.name or "NotDiamond.amodel_select"} + ) + patched_client_functions = [ + SymbolPatcher( + lambda: importlib.import_module("notdiamond"), + "NotDiamond.model_select", + passthrough_wrapper(model_select_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("notdiamond"), + "NotDiamond.amodel_select", + passthrough_wrapper(async_model_select_settings), + ), + ] + + llm_config_init_settings = base.model_copy( + update={"name": base.name or "NotDiamond.LLMConfig.__init__"} + ) + llm_config_from_string_settings = base.model_copy( + update={"name": base.name or "NotDiamond.LLMConfig.from_string"} + ) + patched_llmconfig_functions = [ + SymbolPatcher( + lambda: importlib.import_module("notdiamond"), + "LLMConfig.__init__", + passthrough_wrapper(llm_config_init_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("notdiamond"), + "LLMConfig.from_string", + passthrough_wrapper(llm_config_from_string_settings), + ), + ] + + toolkit_custom_router_fit_settings = base.model_copy( + update={"name": base.name or "NotDiamond.toolkit.custom_router.fit"} + ) + toolkit_custom_router_eval_settings = base.model_copy( + update={"name": base.name or "NotDiamond.toolkit.custom_router.eval"} + ) + patched_toolkit_functions = [ + SymbolPatcher( + lambda: importlib.import_module("notdiamond.toolkit.custom_router"), + "CustomRouter.fit", + passthrough_wrapper(toolkit_custom_router_fit_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("notdiamond.toolkit.custom_router"), + "CustomRouter.eval", + passthrough_wrapper(toolkit_custom_router_eval_settings), + ), + ] + + all_patched_functions = ( + patched_client_functions + + patched_toolkit_functions + + patched_llmconfig_functions + ) + + _notdiamond_patcher = MultiPatcher(all_patched_functions) + + return _notdiamond_patcher diff --git a/weave/integrations/vertexai/vertexai_sdk.py b/weave/integrations/vertexai/vertexai_sdk.py index c2f7a9906c78..a64620cbe5fb 100644 --- a/weave/integrations/vertexai/vertexai_sdk.py +++ b/weave/integrations/vertexai/vertexai_sdk.py @@ -1,10 +1,13 @@ +from __future__ import annotations + import importlib from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher from weave.trace.serialize import dictify from weave.trace.weave_client import Call @@ -12,6 +15,9 @@ from vertexai.generative_models import GenerationResponse +_vertexai_patcher: MultiPatcher | None = None + + def vertexai_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: if "self" in inputs: model_name = ( @@ -25,8 +31,8 @@ def vertexai_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: def vertexai_accumulator( - acc: Optional["GenerationResponse"], value: "GenerationResponse" -) -> "GenerationResponse": + acc: GenerationResponse | None, value: GenerationResponse +) -> GenerationResponse: from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types from google.cloud.aiplatform_v1beta1.types import ( prediction_service as gapic_prediction_service_types, @@ -62,7 +68,7 @@ def vertexai_accumulator( def vertexai_on_finish( - call: Call, output: Any, exception: Optional[BaseException] + call: Call, output: Any, exception: BaseException | None ) -> None: original_model_name = call.inputs["model_name"] model_name = original_model_name.split("/")[-1] @@ -81,10 +87,13 @@ def vertexai_on_finish( call.summary.update(summary_update) -def vertexai_wrapper_sync(name: str) -> Callable[[Callable], Callable]: +def vertexai_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op(postprocess_inputs=vertexai_postprocess_inputs)(fn) - op.name = name # type: ignore + op_kwargs = settings.model_copy() + if not op_kwargs.get("postprocess_inputs"): + op_kwargs["postprocess_inputs"] = vertexai_postprocess_inputs + + op = weave.op(fn, **op_kwargs) op._set_on_finish_handler(vertexai_on_finish) return add_accumulator( op, # type: ignore @@ -96,7 +105,7 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def vertexai_wrapper_async(name: str) -> Callable[[Callable], Callable]: +def vertexai_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) @@ -105,9 +114,11 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper - "We need to do this so we can check if `stream` is used" - op = weave.op(postprocess_inputs=vertexai_postprocess_inputs)(_fn_wrapper(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_copy() + if not op_kwargs.get("postprocess_inputs"): + op_kwargs["postprocess_inputs"] = vertexai_postprocess_inputs + + op = weave.op(_fn_wrapper(fn), **op_kwargs) op._set_on_finish_handler(vertexai_on_finish) return add_accumulator( op, # type: ignore @@ -119,34 +130,65 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -vertexai_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "GenerativeModel.generate_content", - vertexai_wrapper_sync(name="vertexai.GenerativeModel.generate_content"), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "GenerativeModel.generate_content_async", - vertexai_wrapper_async( - name="vertexai.GenerativeModel.generate_content_async" +def get_vertexai_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _vertexai_patcher + if _vertexai_patcher is not None: + return _vertexai_patcher + + base = settings.op_settings + + generate_content_settings = base.model_copy( + update={"name": base.name or "vertexai.GenerativeModel.generate_content"} + ) + generate_content_async_settings = base.model_copy( + update={"name": base.name or "vertexai.GenerativeModel.generate_content_async"} + ) + send_message_settings = base.model_copy( + update={"name": base.name or "vertexai.ChatSession.send_message"} + ) + send_message_async_settings = base.model_copy( + update={"name": base.name or "vertexai.ChatSession.send_message_async"} + ) + generate_images_settings = base.model_copy( + update={"name": base.name or "vertexai.ImageGenerationModel.generate_images"} + ) + + _vertexai_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "GenerativeModel.generate_content", + vertexai_wrapper_sync(generate_content_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "GenerativeModel.generate_content_async", + vertexai_wrapper_async(generate_content_async_settings), ), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "ChatSession.send_message", - vertexai_wrapper_sync(name="vertexai.ChatSession.send_message"), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "ChatSession.send_message_async", - vertexai_wrapper_async(name="vertexai.ChatSession.send_message_async"), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.preview.vision_models"), - "ImageGenerationModel.generate_images", - vertexai_wrapper_sync(name="vertexai.ImageGenerationModel.generate_images"), - ), - ] -) + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "ChatSession.send_message", + vertexai_wrapper_sync(send_message_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "ChatSession.send_message_async", + vertexai_wrapper_async(send_message_async_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("vertexai.preview.vision_models"), + "ImageGenerationModel.generate_images", + vertexai_wrapper_sync(generate_images_settings), + ), + ] + ) + + return _vertexai_patcher diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py index 0619194a2247..c1c47d375127 100644 --- a/weave/trace/autopatch.py +++ b/weave/trace/autopatch.py @@ -34,89 +34,89 @@ class AutopatchSettings(BaseModel): # These will be uncommented as we add support for more integrations. Note that - # anthropic: IntegrationSettings = Field(default_factory=IntegrationSettings) - # cerebras: IntegrationSettings = Field(default_factory=IntegrationSettings) - # cohere: IntegrationSettings = Field(default_factory=IntegrationSettings) - # dspy: IntegrationSettings = Field(default_factory=IntegrationSettings) - # google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings) - # groq: IntegrationSettings = Field(default_factory=IntegrationSettings) - # instructor: IntegrationSettings = Field(default_factory=IntegrationSettings) - # langchain: IntegrationSettings = Field(default_factory=IntegrationSettings) - # litellm: IntegrationSettings = Field(default_factory=IntegrationSettings) - # llamaindex: IntegrationSettings = Field(default_factory=IntegrationSettings) - # mistral: IntegrationSettings = Field(default_factory=IntegrationSettings) - # notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings) + anthropic: IntegrationSettings = Field(default_factory=IntegrationSettings) + cerebras: IntegrationSettings = Field(default_factory=IntegrationSettings) + cohere: IntegrationSettings = Field(default_factory=IntegrationSettings) + dspy: IntegrationSettings = Field(default_factory=IntegrationSettings) + google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings) + groq: IntegrationSettings = Field(default_factory=IntegrationSettings) + instructor: IntegrationSettings = Field(default_factory=IntegrationSettings) + litellm: IntegrationSettings = Field(default_factory=IntegrationSettings) + mistral: IntegrationSettings = Field(default_factory=IntegrationSettings) + notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings) openai: IntegrationSettings = Field(default_factory=IntegrationSettings) - # vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings) + vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings) @validate_call def autopatch(settings: Optional[AutopatchSettings] = None) -> None: - from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher - from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher - from weave.integrations.cohere.cohere_sdk import cohere_patcher - from weave.integrations.dspy.dspy_sdk import dspy_patcher + from weave.integrations.anthropic.anthropic_sdk import get_anthropic_patcher + from weave.integrations.cerebras.cerebras_sdk import get_cerebras_patcher + from weave.integrations.cohere.cohere_sdk import get_cohere_patcher + from weave.integrations.dspy.dspy_sdk import get_dspy_patcher from weave.integrations.google_ai_studio.google_ai_studio_sdk import ( - google_genai_patcher, + get_google_genai_patcher, ) - from weave.integrations.groq.groq_sdk import groq_patcher - from weave.integrations.instructor.instructor_sdk import instructor_patcher + from weave.integrations.groq.groq_sdk import get_groq_patcher + from weave.integrations.instructor.instructor_sdk import get_instructor_patcher from weave.integrations.langchain.langchain import langchain_patcher - from weave.integrations.litellm.litellm import litellm_patcher + from weave.integrations.litellm.litellm import get_litellm_patcher from weave.integrations.llamaindex.llamaindex import llamaindex_patcher - from weave.integrations.mistral import mistral_patcher - from weave.integrations.notdiamond.tracing import notdiamond_patcher + from weave.integrations.mistral import get_mistral_patcher + from weave.integrations.notdiamond.tracing import get_notdiamond_patcher from weave.integrations.openai.openai_sdk import get_openai_patcher - from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher + from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher if settings is None: settings = AutopatchSettings() get_openai_patcher(settings.openai).attempt_patch() - mistral_patcher.attempt_patch() - litellm_patcher.attempt_patch() + get_mistral_patcher(settings.mistral).attempt_patch() + get_litellm_patcher(settings.litellm).attempt_patch() + get_anthropic_patcher(settings.anthropic).attempt_patch() + get_groq_patcher(settings.groq).attempt_patch() + get_instructor_patcher(settings.instructor).attempt_patch() + get_dspy_patcher(settings.dspy).attempt_patch() + get_cerebras_patcher(settings.cerebras).attempt_patch() + get_cohere_patcher(settings.cohere).attempt_patch() + get_google_genai_patcher(settings.google_ai_studio).attempt_patch() + get_notdiamond_patcher(settings.notdiamond).attempt_patch() + get_vertexai_patcher(settings.vertexai).attempt_patch() + llamaindex_patcher.attempt_patch() langchain_patcher.attempt_patch() - anthropic_patcher.attempt_patch() - groq_patcher.attempt_patch() - instructor_patcher.attempt_patch() - dspy_patcher.attempt_patch() - cerebras_patcher.attempt_patch() - cohere_patcher.attempt_patch() - google_genai_patcher.attempt_patch() - notdiamond_patcher.attempt_patch() - vertexai_patcher.attempt_patch() def reset_autopatch() -> None: - from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher - from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher - from weave.integrations.cohere.cohere_sdk import cohere_patcher - from weave.integrations.dspy.dspy_sdk import dspy_patcher + from weave.integrations.anthropic.anthropic_sdk import get_anthropic_patcher + from weave.integrations.cerebras.cerebras_sdk import get_cerebras_patcher + from weave.integrations.cohere.cohere_sdk import get_cohere_patcher + from weave.integrations.dspy.dspy_sdk import get_dspy_patcher from weave.integrations.google_ai_studio.google_ai_studio_sdk import ( - google_genai_patcher, + get_google_genai_patcher, ) - from weave.integrations.groq.groq_sdk import groq_patcher - from weave.integrations.instructor.instructor_sdk import instructor_patcher + from weave.integrations.groq.groq_sdk import get_groq_patcher + from weave.integrations.instructor.instructor_sdk import get_instructor_patcher from weave.integrations.langchain.langchain import langchain_patcher - from weave.integrations.litellm.litellm import litellm_patcher + from weave.integrations.litellm.litellm import get_litellm_patcher from weave.integrations.llamaindex.llamaindex import llamaindex_patcher - from weave.integrations.mistral import mistral_patcher - from weave.integrations.notdiamond.tracing import notdiamond_patcher + from weave.integrations.mistral import get_mistral_patcher + from weave.integrations.notdiamond.tracing import get_notdiamond_patcher from weave.integrations.openai.openai_sdk import get_openai_patcher - from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher + from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher get_openai_patcher().undo_patch() - mistral_patcher.undo_patch() - litellm_patcher.undo_patch() + get_mistral_patcher().undo_patch() + get_litellm_patcher().undo_patch() + get_anthropic_patcher().undo_patch() + get_groq_patcher().undo_patch() + get_instructor_patcher().undo_patch() + get_dspy_patcher().undo_patch() + get_cerebras_patcher().undo_patch() + get_cohere_patcher().undo_patch() + get_google_genai_patcher().undo_patch() + get_notdiamond_patcher().undo_patch() + get_vertexai_patcher().undo_patch() + llamaindex_patcher.undo_patch() langchain_patcher.undo_patch() - anthropic_patcher.undo_patch() - groq_patcher.undo_patch() - instructor_patcher.undo_patch() - dspy_patcher.undo_patch() - cerebras_patcher.undo_patch() - cohere_patcher.undo_patch() - google_genai_patcher.undo_patch() - notdiamond_patcher.undo_patch() - vertexai_patcher.undo_patch()