From da6fb3958aa8a1945dd50beaa72b5839d964f691 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 11 Dec 2024 18:09:40 -0500 Subject: [PATCH] Make integration patchers configurable via IntegrationSettings --- weave/integrations/cerebras/cerebras_sdk.py | 8 +- weave/integrations/cohere/cohere_sdk.py | 6 +- .../google_ai_studio/google_ai_studio_sdk.py | 8 +- weave/integrations/groq/groq_sdk.py | 6 +- .../instructor/instructor_iterable_utils.py | 14 +- .../instructor/instructor_partial_utils.py | 7 +- .../integrations/instructor/instructor_sdk.py | 79 +++++++---- weave/integrations/litellm/litellm.py | 57 +++++--- weave/integrations/mistral/__init__.py | 4 +- weave/integrations/mistral/v0/mistral.py | 106 +++++++++----- weave/integrations/mistral/v1/mistral.py | 96 ++++++++----- weave/integrations/notdiamond/__init__.py | 2 +- weave/integrations/notdiamond/tracing.py | 131 +++++++++++++----- weave/integrations/vertexai/vertexai_sdk.py | 112 ++++++++++----- weave/trace/autopatch.py | 121 ++++++++-------- 15 files changed, 490 insertions(+), 267 deletions(-) diff --git a/weave/integrations/cerebras/cerebras_sdk.py b/weave/integrations/cerebras/cerebras_sdk.py index 15fa1cf7935..1b6fee49006 100644 --- a/weave/integrations/cerebras/cerebras_sdk.py +++ b/weave/integrations/cerebras/cerebras_sdk.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Optional import weave -from weave.trace.autopatch import OpSettings +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.patcher import MultiPatcher, SymbolPatcher _cerebras_patcher: Optional[MultiPatcher] = None @@ -34,14 +34,16 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -def get_cerebras_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher: +def get_cerebras_patcher( + settings: Optional[IntegrationSettings] = None, +) -> MultiPatcher: global _cerebras_patcher if _cerebras_patcher is not None: return _cerebras_patcher if settings is None: - settings = OpSettings() + settings = IntegrationSettings() base = settings.op_settings diff --git a/weave/integrations/cohere/cohere_sdk.py b/weave/integrations/cohere/cohere_sdk.py index aaaa7b10639..e554e659478 100644 --- a/weave/integrations/cohere/cohere_sdk.py +++ b/weave/integrations/cohere/cohere_sdk.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional import weave -from weave.trace.autopatch import OpSettings +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher @@ -185,14 +185,14 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def get_cohere_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher: +def get_cohere_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: global _cohere_patcher if _cohere_patcher is not None: return _cohere_patcher if settings is None: - settings = OpSettings() + settings = IntegrationSettings() base = settings.op_settings diff --git a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py index 68c2a0b39ea..d86fe5b11f3 100644 --- a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py +++ b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional import weave -from weave.trace.autopatch import OpSettings +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator
import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher from weave.trace.serialize import dictify @@ -135,14 +135,16 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -def get_google_genai_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher: +def get_google_genai_patcher( + settings: Optional[IntegrationSettings] = None, +) -> MultiPatcher: global _google_genai_patcher if _google_genai_patcher is not None: return _google_genai_patcher if settings is None: - settings = OpSettings() + settings = IntegrationSettings() base = settings.op_settings diff --git a/weave/integrations/groq/groq_sdk.py b/weave/integrations/groq/groq_sdk.py index 2de5462da90..fe8db33bcf5 100644 --- a/weave/integrations/groq/groq_sdk.py +++ b/weave/integrations/groq/groq_sdk.py @@ -1,7 +1,7 @@ import importlib from typing import TYPE_CHECKING, Callable, Optional -from weave.trace.autopatch import OpSettings +from weave.trace.autopatch import IntegrationSettings, OpSettings if TYPE_CHECKING: from groq.types.chat import ChatCompletion, ChatCompletionChunk @@ -100,14 +100,14 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def get_groq_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher: +def get_groq_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: global _groq_patcher if _groq_patcher is not None: return _groq_patcher if settings is None: - settings = OpSettings() + settings = IntegrationSettings() base = settings.op_settings diff --git a/weave/integrations/instructor/instructor_iterable_utils.py b/weave/integrations/instructor/instructor_iterable_utils.py index 84d64a103b6..3b0f128a132 100644 --- a/weave/integrations/instructor/instructor_iterable_utils.py +++ b/weave/integrations/instructor/instructor_iterable_utils.py @@ -4,6 +4,7 @@ from pydantic import BaseModel import weave +from weave.trace.autopatch import OpSettings from weave.trace.op_extensions.accumulator import add_accumulator @@ -27,10 +28,10 @@ def should_accumulate_iterable(inputs: dict) -> bool: return False -def instructor_wrapper_sync(name: str) -> Callable[[Callable], Callable]: +def instructor_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: instructor_iterable_accumulator, @@ -40,7 +41,7 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def instructor_wrapper_async(name: str) -> Callable[[Callable], Callable]: +def instructor_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) @@ -49,9 +50,8 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper - "We need to do this so we can check if `stream` is used" - op = weave.op(_fn_wrapper(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_fn_wrapper(fn), **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: instructor_iterable_accumulator, diff --git a/weave/integrations/instructor/instructor_partial_utils.py b/weave/integrations/instructor/instructor_partial_utils.py index f90dc7edb17..8efa84b302f 100644 --- a/weave/integrations/instructor/instructor_partial_utils.py +++ 
b/weave/integrations/instructor/instructor_partial_utils.py @@ -3,6 +3,7 @@ from pydantic import BaseModel import weave +from weave.trace.autopatch import OpSettings from weave.trace.op_extensions.accumulator import add_accumulator @@ -14,10 +15,10 @@ def instructor_partial_accumulator( return acc -def instructor_wrapper_partial(name: str) -> Callable[[Callable], Callable]: +def instructor_wrapper_partial(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: instructor_partial_accumulator, diff --git a/weave/integrations/instructor/instructor_sdk.py b/weave/integrations/instructor/instructor_sdk.py index 867e9f2a785..2716d570cbb 100644 --- a/weave/integrations/instructor/instructor_sdk.py +++ b/weave/integrations/instructor/instructor_sdk.py @@ -1,31 +1,62 @@ import importlib +from typing import Optional +from weave.trace.autopatch import IntegrationSettings from weave.trace.patcher import MultiPatcher, SymbolPatcher from .instructor_iterable_utils import instructor_wrapper_async, instructor_wrapper_sync from .instructor_partial_utils import instructor_wrapper_partial -instructor_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("instructor.client"), - "Instructor.create", - instructor_wrapper_sync(name="Instructor.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("instructor.client"), - "AsyncInstructor.create", - instructor_wrapper_async(name="AsyncInstructor.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("instructor.client"), - "Instructor.create_partial", - instructor_wrapper_partial(name="Instructor.create_partial"), - ), - SymbolPatcher( - lambda: importlib.import_module("instructor.client"), - "AsyncInstructor.create_partial", - instructor_wrapper_partial(name="AsyncInstructor.create_partial"), - ), - ] -) +_instructor_patcher: Optional[MultiPatcher] = None + + +def get_instructor_patcher( + settings: Optional[IntegrationSettings] = None, +) -> MultiPatcher: + global _instructor_patcher + + if _instructor_patcher is not None: + return _instructor_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + + create_settings = base.model_copy(update={"name": base.name or "Instructor.create"}) + async_create_settings = base.model_copy( + update={"name": base.name or "AsyncInstructor.create"} + ) + create_partial_settings = base.model_copy( + update={"name": base.name or "Instructor.create_partial"} + ) + async_create_partial_settings = base.model_copy( + update={"name": base.name or "AsyncInstructor.create_partial"} + ) + + _instructor_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("instructor.client"), + "Instructor.create", + instructor_wrapper_sync(create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("instructor.client"), + "AsyncInstructor.create", + instructor_wrapper_async(async_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("instructor.client"), + "Instructor.create_partial", + instructor_wrapper_partial(create_partial_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("instructor.client"), + "AsyncInstructor.create_partial", + instructor_wrapper_partial(async_create_partial_settings), + ), + ] + ) + + return _instructor_patcher diff --git 
a/weave/integrations/litellm/litellm.py b/weave/integrations/litellm/litellm.py index c3bf1bf114a..8a1818820a0 100644 --- a/weave/integrations/litellm/litellm.py +++ b/weave/integrations/litellm/litellm.py @@ -2,12 +2,15 @@ from typing import TYPE_CHECKING, Any, Callable, Optional import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher if TYPE_CHECKING: from litellm.utils import ModelResponse +_litellm_patcher: Optional[MultiPatcher] = None + # This accumulator is nearly identical to the mistral accumulator, just with different types. def litellm_accumulator( @@ -82,10 +85,10 @@ def should_use_accumulator(inputs: dict) -> bool: return isinstance(inputs, dict) and bool(inputs.get("stream")) -def make_wrapper(name: str) -> Callable: +def make_wrapper(settings: OpSettings) -> Callable: def litellm_wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: litellm_accumulator, @@ -96,17 +99,37 @@ def litellm_wrapper(fn: Callable) -> Callable: return litellm_wrapper -litellm_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("litellm"), - "completion", - make_wrapper("litellm.completion"), - ), - SymbolPatcher( - lambda: importlib.import_module("litellm"), - "acompletion", - make_wrapper("litellm.acompletion"), - ), - ] -) +def get_litellm_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: + global _litellm_patcher + + if _litellm_patcher is not None: + return _litellm_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + + completion_settings = base.model_copy( + update={"name": base.name or "litellm.completion"} + ) + acompletion_settings = base.model_copy( + update={"name": base.name or "litellm.acompletion"} + ) + + _litellm_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("litellm"), + "completion", + make_wrapper(completion_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("litellm"), + "acompletion", + make_wrapper(acompletion_settings), + ), + ] + ) + + return _litellm_patcher diff --git a/weave/integrations/mistral/__init__.py b/weave/integrations/mistral/__init__.py index 34b40835efc..d78812d9afa 100644 --- a/weave/integrations/mistral/__init__.py +++ b/weave/integrations/mistral/__init__.py @@ -8,10 +8,10 @@ mistral_version = "1.0" # we need to return a patching function if version.parse(mistral_version) < version.parse("1.0.0"): - from .v0.mistral import mistral_patcher + from .v0.mistral import get_mistral_patcher # noqa: F401 print( f"Using MistralAI version {mistral_version}. Please consider upgrading to version 1.0.0 or later." 
) else: - from .v1.mistral import mistral_patcher # noqa: F401 + from .v1.mistral import get_mistral_patcher # noqa: F401 diff --git a/weave/integrations/mistral/v0/mistral.py b/weave/integrations/mistral/v0/mistral.py index 6d4915eab41..70c81f449e2 100644 --- a/weave/integrations/mistral/v0/mistral.py +++ b/weave/integrations/mistral/v0/mistral.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Callable, Optional import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher @@ -11,6 +12,8 @@ ChatCompletionStreamResponse, ) +_mistral_patcher: Optional[MultiPatcher] = None + def mistral_accumulator( acc: Optional["ChatCompletionResponse"], @@ -72,37 +75,72 @@ def mistral_accumulator( return acc -def mistral_stream_wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - acc_op = add_accumulator(op, lambda inputs: mistral_accumulator) # type: ignore - return acc_op - - -mistral_patcher = MultiPatcher( - [ - # Patch the sync, non-streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.client"), - "MistralClient.chat", - weave.op(), - ), - # Patch the sync, streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.client"), - "MistralClient.chat_stream", - mistral_stream_wrapper, - ), - # Patch the async, non-streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.async_client"), - "MistralAsyncClient.chat", - weave.op(), - ), - # Patch the async, streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.async_client"), - "MistralAsyncClient.chat_stream", - mistral_stream_wrapper, - ), - ] -) +def mistral_stream_wrapper(settings: OpSettings) -> Callable: + def wrapper(fn: Callable) -> Callable: + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) + acc_op = add_accumulator(op, lambda inputs: mistral_accumulator) # type: ignore + return acc_op + + return wrapper + + +def mistral_wrapper(settings: OpSettings) -> Callable: + def wrapper(fn: Callable) -> Callable: + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) + return op + + return wrapper + + +def get_mistral_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: + global _mistral_patcher + + if _mistral_patcher is not None: + return _mistral_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + + chat_settings = base.model_copy(update={"name": base.name or "mistralai.chat"}) + chat_stream_settings = base.model_copy( + update={"name": base.name or "mistralai.chat_stream"} + ) + async_chat_settings = base.model_copy( + update={"name": base.name or "mistralai.async_client.chat"} + ) + async_chat_stream_settings = base.model_copy( + update={"name": base.name or "mistralai.async_client.chat_stream"} + ) + + _mistral_patcher = MultiPatcher( + [ + # Patch the sync, non-streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.client"), + "MistralClient.chat", + mistral_wrapper(chat_settings), + ), + # Patch the sync, streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.client"), + "MistralClient.chat_stream", + mistral_stream_wrapper(chat_stream_settings), + ), + # Patch the async, non-streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.async_client"), + "MistralAsyncClient.chat",
mistral_wrapper(async_chat_settings), + ), + # Patch the async, streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.async_client"), + "MistralAsyncClient.chat_stream", + mistral_stream_wrapper(async_chat_stream_settings), + ), + ] + ) + + return _mistral_patcher diff --git a/weave/integrations/mistral/v1/mistral.py b/weave/integrations/mistral/v1/mistral.py index 692aa7b159b..a7824803c40 100644 --- a/weave/integrations/mistral/v1/mistral.py +++ b/weave/integrations/mistral/v1/mistral.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Callable, Optional import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher @@ -11,6 +12,8 @@ CompletionEvent, ) +_mistral_patcher: Optional[MultiPatcher] = None + def mistral_accumulator( acc: Optional["ChatCompletionResponse"], @@ -79,50 +82,75 @@ def mistral_accumulator( return acc -def mistral_stream_wrapper(name: str) -> Callable: +def mistral_stream_wrapper(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) acc_op = add_accumulator(op, lambda inputs: mistral_accumulator) # type: ignore - acc_op.name = name # type: ignore return acc_op return wrapper -def mistral_wrapper(name: str) -> Callable: +def mistral_wrapper(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return op return wrapper -mistral_patcher = MultiPatcher( - [ - # Patch the sync, non-streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.chat"), - "Chat.complete", - mistral_wrapper(name="Mistral.chat.complete"), - ), - # Patch the sync, streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.chat"), - "Chat.stream", - mistral_stream_wrapper(name="Mistral.chat.stream"), - ), - # Patch the async, non-streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.chat"), - "Chat.complete_async", - mistral_wrapper(name="Mistral.chat.complete_async"), - ), - # Patch the async, streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.chat"), - "Chat.stream_async", - mistral_stream_wrapper(name="Mistral.chat.stream_async"), - ), - ] -) +def get_mistral_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: + global _mistral_patcher + + if _mistral_patcher is not None: + return _mistral_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + chat_complete_settings = base.model_copy( + update={"name": base.name or "mistralai.chat.complete"} + ) + chat_stream_settings = base.model_copy( + update={"name": base.name or "mistralai.chat.stream"} + ) + async_chat_complete_settings = base.model_copy( + update={"name": base.name or "mistralai.async_client.chat.complete"} + ) + async_chat_stream_settings = base.model_copy( + update={"name": base.name or "mistralai.async_client.chat.stream"} + ) + + _mistral_patcher = MultiPatcher( + [ + # Patch the sync, non-streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.chat"), + "Chat.complete", + mistral_wrapper(chat_complete_settings), + ), + # Patch the sync, streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.chat"),
"Chat.stream", + mistral_stream_wrapper(chat_stream_settings), + ), + # Patch the async, non-streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.chat"), + "Chat.complete_async", + mistral_wrapper(async_chat_complete_settings), + ), + # Patch the async, streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.chat"), + "Chat.stream_async", + mistral_stream_wrapper(async_chat_stream_settings), + ), + ] + ) + + return _mistral_patcher diff --git a/weave/integrations/notdiamond/__init__.py b/weave/integrations/notdiamond/__init__.py index d99c31c4176..8cb72ef2a55 100644 --- a/weave/integrations/notdiamond/__init__.py +++ b/weave/integrations/notdiamond/__init__.py @@ -1 +1 @@ -from .tracing import notdiamond_patcher as notdiamond_patcher +from .tracing import get_notdiamond_patcher # noqa: F401 diff --git a/weave/integrations/notdiamond/tracing.py b/weave/integrations/notdiamond/tracing.py index 90c08b9e8c6..ae1e5217725 100644 --- a/weave/integrations/notdiamond/tracing.py +++ b/weave/integrations/notdiamond/tracing.py @@ -1,19 +1,30 @@ import importlib -from typing import Callable +from typing import Callable, Optional import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.patcher import MultiPatcher, SymbolPatcher +_notdiamond_patcher: Optional[MultiPatcher] = None -def nd_wrapper(name: str) -> Callable[[Callable], Callable]: + +def nd_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return op return wrapper +def passthrough_wrapper(settings: OpSettings) -> Callable: + def wrapper(fn: Callable) -> Callable: + op_kwargs = settings.model_dump() + return weave.op(fn, **op_kwargs) + + return wrapper + + def _patch_client_op(method_name: str) -> list[SymbolPatcher]: return [ SymbolPatcher( @@ -29,36 +40,82 @@ def _patch_client_op(method_name: str) -> list[SymbolPatcher]: ] -patched_client_functions = _patch_client_op("model_select") - -patched_llmconfig_functions = [ - SymbolPatcher( - lambda: importlib.import_module("notdiamond"), - "LLMConfig.__init__", - weave.op(), - ), - SymbolPatcher( - lambda: importlib.import_module("notdiamond"), - "LLMConfig.from_string", - weave.op(), - ), -] - -patched_toolkit_functions = [ - SymbolPatcher( - lambda: importlib.import_module("notdiamond.toolkit.custom_router"), - "CustomRouter.fit", - weave.op(), - ), - SymbolPatcher( - lambda: importlib.import_module("notdiamond.toolkit.custom_router"), - "CustomRouter.eval", - weave.op(), - ), -] - -all_patched_functions = ( - patched_client_functions + patched_toolkit_functions + patched_llmconfig_functions -) - -notdiamond_patcher = MultiPatcher(all_patched_functions) +def get_notdiamond_patcher( + settings: Optional[IntegrationSettings] = None, +) -> MultiPatcher: + global _notdiamond_patcher + + if _notdiamond_patcher is not None: + return _notdiamond_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + + model_select_settings = base.model_copy( + update={"name": base.name or "NotDiamond.model_select"} + ) + async_model_select_settings = base.model_copy( + update={"name": base.name or "NotDiamond.amodel_select"} + ) + patched_client_functions = [ + SymbolPatcher( + lambda: importlib.import_module("notdiamond"), + "NotDiamond.model_select", + passthrough_wrapper(model_select_settings), + ), + 
SymbolPatcher( + lambda: importlib.import_module("notdiamond"), + "NotDiamond.amodel_select", + passthrough_wrapper(async_model_select_settings), + ), + ] + + llm_config_init_settings = base.model_copy( + update={"name": base.name or "NotDiamond.LLMConfig.__init__"} + ) + llm_config_from_string_settings = base.model_copy( + update={"name": base.name or "NotDiamond.LLMConfig.from_string"} + ) + patched_llmconfig_functions = [ + SymbolPatcher( + lambda: importlib.import_module("notdiamond"), + "LLMConfig.__init__", + passthrough_wrapper(llm_config_init_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("notdiamond"), + "LLMConfig.from_string", + passthrough_wrapper(llm_config_from_string_settings), + ), + ] + + toolkit_custom_router_fit_settings = base.model_copy( + update={"name": base.name or "NotDiamond.toolkit.custom_router.fit"} + ) + toolkit_custom_router_eval_settings = base.model_copy( + update={"name": base.name or "NotDiamond.toolkit.custom_router.eval"} + ) + patched_toolkit_functions = [ + SymbolPatcher( + lambda: importlib.import_module("notdiamond.toolkit.custom_router"), + "CustomRouter.fit", + passthrough_wrapper(toolkit_custom_router_fit_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("notdiamond.toolkit.custom_router"), + "CustomRouter.eval", + passthrough_wrapper(toolkit_custom_router_eval_settings), + ), + ] + + all_patched_functions = ( + patched_client_functions + + patched_toolkit_functions + + patched_llmconfig_functions + ) + + _notdiamond_patcher = MultiPatcher(all_patched_functions) + + return _notdiamond_patcher diff --git a/weave/integrations/vertexai/vertexai_sdk.py b/weave/integrations/vertexai/vertexai_sdk.py index c2f7a9906c7..f016135d775 100644 --- a/weave/integrations/vertexai/vertexai_sdk.py +++ b/weave/integrations/vertexai/vertexai_sdk.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher from weave.trace.serialize import dictify @@ -12,6 +13,9 @@ from vertexai.generative_models import GenerationResponse +_vertexai_patcher: Optional[MultiPatcher] = None + + def vertexai_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: if "self" in inputs: model_name = ( @@ -81,10 +85,13 @@ def vertexai_on_finish( call.summary.update(summary_update) -def vertexai_wrapper_sync(name: str) -> Callable[[Callable], Callable]: +def vertexai_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op(postprocess_inputs=vertexai_postprocess_inputs)(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + if not op_kwargs.get("postprocess_inputs"): + op_kwargs["postprocess_inputs"] = vertexai_postprocess_inputs + + op = weave.op(fn, **op_kwargs) op._set_on_finish_handler(vertexai_on_finish) return add_accumulator( op, # type: ignore @@ -96,7 +103,7 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def vertexai_wrapper_async(name: str) -> Callable[[Callable], Callable]: +def vertexai_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: @@ -105,9 +112,11 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper - "We need to do this so we can check if `stream` is used" - op =
weave.op(postprocess_inputs=vertexai_postprocess_inputs)(_fn_wrapper(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + if not op_kwargs.get("postprocess_inputs"): + op_kwargs["postprocess_inputs"] = vertexai_postprocess_inputs + + op = weave.op(_fn_wrapper(fn), **op_kwargs) op._set_on_finish_handler(vertexai_on_finish) return add_accumulator( op, # type: ignore @@ -119,34 +128,63 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -vertexai_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "GenerativeModel.generate_content", - vertexai_wrapper_sync(name="vertexai.GenerativeModel.generate_content"), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "GenerativeModel.generate_content_async", - vertexai_wrapper_async( - name="vertexai.GenerativeModel.generate_content_async" +def get_vertexai_patcher( + settings: Optional[IntegrationSettings] = None, +) -> MultiPatcher: + global _vertexai_patcher + + if _vertexai_patcher is not None: + return _vertexai_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + + generate_content_settings = base.model_copy( + update={"name": base.name or "vertexai.GenerativeModel.generate_content"} + ) + generate_content_async_settings = base.model_copy( + update={"name": base.name or "vertexai.GenerativeModel.generate_content_async"} + ) + send_message_settings = base.model_copy( + update={"name": base.name or "vertexai.ChatSession.send_message"} + ) + send_message_async_settings = base.model_copy( + update={"name": base.name or "vertexai.ChatSession.send_message_async"} + ) + generate_images_settings = base.model_copy( + update={"name": base.name or "vertexai.ImageGenerationModel.generate_images"} + ) + + _vertexai_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "GenerativeModel.generate_content", + vertexai_wrapper_sync(generate_content_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "GenerativeModel.generate_content_async", + vertexai_wrapper_async(generate_content_async_settings), ), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "ChatSession.send_message", - vertexai_wrapper_sync(name="vertexai.ChatSession.send_message"), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "ChatSession.send_message_async", - vertexai_wrapper_async(name="vertexai.ChatSession.send_message_async"), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.preview.vision_models"), - "ImageGenerationModel.generate_images", - vertexai_wrapper_sync(name="vertexai.ImageGenerationModel.generate_images"), - ), - ] -) + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "ChatSession.send_message", + vertexai_wrapper_sync(send_message_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "ChatSession.send_message_async", + vertexai_wrapper_async(send_message_async_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("vertexai.preview.vision_models"), + "ImageGenerationModel.generate_images", + vertexai_wrapper_sync(generate_images_settings), + ), + ] + ) + + return _vertexai_patcher diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py index 0619194a224..c79cf3b315d 100644 ---
a/weave/trace/autopatch.py +++ b/weave/trace/autopatch.py @@ -34,89 +34,92 @@ class AutopatchSettings(BaseModel): # These will be uncommented as we add support for more integrations. Note that - # anthropic: IntegrationSettings = Field(default_factory=IntegrationSettings) - # cerebras: IntegrationSettings = Field(default_factory=IntegrationSettings) - # cohere: IntegrationSettings = Field(default_factory=IntegrationSettings) - # dspy: IntegrationSettings = Field(default_factory=IntegrationSettings) - # google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings) - # groq: IntegrationSettings = Field(default_factory=IntegrationSettings) - # instructor: IntegrationSettings = Field(default_factory=IntegrationSettings) - # langchain: IntegrationSettings = Field(default_factory=IntegrationSettings) - # litellm: IntegrationSettings = Field(default_factory=IntegrationSettings) - # llamaindex: IntegrationSettings = Field(default_factory=IntegrationSettings) - # mistral: IntegrationSettings = Field(default_factory=IntegrationSettings) - # notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings) + anthropic: IntegrationSettings = Field(default_factory=IntegrationSettings) + cerebras: IntegrationSettings = Field(default_factory=IntegrationSettings) + cohere: IntegrationSettings = Field(default_factory=IntegrationSettings) + dspy: IntegrationSettings = Field(default_factory=IntegrationSettings) + google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings) + groq: IntegrationSettings = Field(default_factory=IntegrationSettings) + instructor: IntegrationSettings = Field(default_factory=IntegrationSettings) + langchain: IntegrationSettings = Field(default_factory=IntegrationSettings) + litellm: IntegrationSettings = Field(default_factory=IntegrationSettings) + llamaindex: IntegrationSettings = Field(default_factory=IntegrationSettings) + mistral: IntegrationSettings = Field(default_factory=IntegrationSettings) + notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings) openai: IntegrationSettings = Field(default_factory=IntegrationSettings) - # vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings) + vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings) @validate_call def autopatch(settings: Optional[AutopatchSettings] = None) -> None: - from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher - from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher - from weave.integrations.cohere.cohere_sdk import cohere_patcher - from weave.integrations.dspy.dspy_sdk import dspy_patcher + from weave.integrations.anthropic.anthropic_sdk import get_anthropic_patcher + from weave.integrations.cerebras.cerebras_sdk import get_cerebras_patcher + from weave.integrations.cohere.cohere_sdk import get_cohere_patcher + from weave.integrations.dspy.dspy_sdk import get_dspy_patcher from weave.integrations.google_ai_studio.google_ai_studio_sdk import ( - google_genai_patcher, + get_google_genai_patcher, ) - from weave.integrations.groq.groq_sdk import groq_patcher - from weave.integrations.instructor.instructor_sdk import instructor_patcher + from weave.integrations.groq.groq_sdk import get_groq_patcher + from weave.integrations.instructor.instructor_sdk import get_instructor_patcher from weave.integrations.langchain.langchain import langchain_patcher - from weave.integrations.litellm.litellm import litellm_patcher + from weave.integrations.litellm.litellm import 
get_litellm_patcher from weave.integrations.llamaindex.llamaindex import llamaindex_patcher - from weave.integrations.mistral import mistral_patcher - from weave.integrations.notdiamond.tracing import notdiamond_patcher + from weave.integrations.mistral import get_mistral_patcher + from weave.integrations.notdiamond.tracing import get_notdiamond_patcher from weave.integrations.openai.openai_sdk import get_openai_patcher - from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher + from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher if settings is None: settings = AutopatchSettings() get_openai_patcher(settings.openai).attempt_patch() - mistral_patcher.attempt_patch() - litellm_patcher.attempt_patch() - llamaindex_patcher.attempt_patch() + get_mistral_patcher(settings.mistral).attempt_patch() + get_litellm_patcher(settings.litellm).attempt_patch() + get_anthropic_patcher(settings.anthropic).attempt_patch() + get_groq_patcher(settings.groq).attempt_patch() + get_instructor_patcher(settings.instructor).attempt_patch() + get_dspy_patcher(settings.dspy).attempt_patch() + get_cerebras_patcher(settings.cerebras).attempt_patch() + get_cohere_patcher(settings.cohere).attempt_patch() + get_google_genai_patcher(settings.google_ai_studio).attempt_patch() + get_notdiamond_patcher(settings.notdiamond).attempt_patch() + get_vertexai_patcher(settings.vertexai).attempt_patch() + + # These integrations don't use the op decorator, so there are no patching settings for them langchain_patcher.attempt_patch() - anthropic_patcher.attempt_patch() - groq_patcher.attempt_patch() - instructor_patcher.attempt_patch() - dspy_patcher.attempt_patch() - cerebras_patcher.attempt_patch() - cohere_patcher.attempt_patch() - google_genai_patcher.attempt_patch() - notdiamond_patcher.attempt_patch() - vertexai_patcher.attempt_patch() + llamaindex_patcher.attempt_patch() def reset_autopatch() -> None: - from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher - from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher - from weave.integrations.cohere.cohere_sdk import cohere_patcher - from weave.integrations.dspy.dspy_sdk import dspy_patcher + from weave.integrations.anthropic.anthropic_sdk import get_anthropic_patcher + from weave.integrations.cerebras.cerebras_sdk import get_cerebras_patcher + from weave.integrations.cohere.cohere_sdk import get_cohere_patcher + from weave.integrations.dspy.dspy_sdk import get_dspy_patcher from weave.integrations.google_ai_studio.google_ai_studio_sdk import ( - google_genai_patcher, + get_google_genai_patcher, ) - from weave.integrations.groq.groq_sdk import groq_patcher - from weave.integrations.instructor.instructor_sdk import instructor_patcher + from weave.integrations.groq.groq_sdk import get_groq_patcher + from weave.integrations.instructor.instructor_sdk import get_instructor_patcher from weave.integrations.langchain.langchain import langchain_patcher - from weave.integrations.litellm.litellm import litellm_patcher + from weave.integrations.litellm.litellm import get_litellm_patcher from weave.integrations.llamaindex.llamaindex import llamaindex_patcher - from weave.integrations.mistral import mistral_patcher - from weave.integrations.notdiamond.tracing import notdiamond_patcher + from weave.integrations.mistral import get_mistral_patcher + from weave.integrations.notdiamond.tracing import get_notdiamond_patcher from weave.integrations.openai.openai_sdk import get_openai_patcher - from weave.integrations.vertexai.vertexai_sdk 
import vertexai_patcher + from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher get_openai_patcher().undo_patch() - mistral_patcher.undo_patch() - litellm_patcher.undo_patch() - llamaindex_patcher.undo_patch() + get_mistral_patcher().undo_patch() + get_litellm_patcher().undo_patch() + get_anthropic_patcher().undo_patch() + get_groq_patcher().undo_patch() + get_instructor_patcher().undo_patch() + get_dspy_patcher().undo_patch() + get_cerebras_patcher().undo_patch() + get_cohere_patcher().undo_patch() + get_google_genai_patcher().undo_patch() + get_notdiamond_patcher().undo_patch() + get_vertexai_patcher().undo_patch() + langchain_patcher.undo_patch() - anthropic_patcher.undo_patch() - groq_patcher.undo_patch() - instructor_patcher.undo_patch() - dspy_patcher.undo_patch() - cerebras_patcher.undo_patch() - cohere_patcher.undo_patch() - google_genai_patcher.undo_patch() - notdiamond_patcher.undo_patch() - vertexai_patcher.undo_patch() + llamaindex_patcher.undo_patch()
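
A minimal usage sketch of the configuration surface this patch introduces, using only names visible in the diff above (AutopatchSettings, IntegrationSettings, OpSettings with its name/postprocess_inputs fields, and the autopatch/reset_autopatch entry points); the redaction helper and the chosen values are hypothetical, not part of the patch:

    from weave.trace.autopatch import (
        AutopatchSettings,
        IntegrationSettings,
        OpSettings,
        autopatch,
        reset_autopatch,
    )


    def redact_inputs(inputs: dict) -> dict:
        # Hypothetical post-processor: drop an assumed "api_key" input before logging.
        return {k: v for k, v in inputs.items() if k != "api_key"}


    settings = AutopatchSettings(
        # OpSettings.name applies to every op in the integration: both
        # litellm.completion and litellm.acompletion fall back to their
        # per-op defaults only when name is unset.
        litellm=IntegrationSettings(op_settings=OpSettings(name="my_app.litellm")),
        # The vertexai wrappers install vertexai_postprocess_inputs only when
        # postprocess_inputs is unset, so this replaces the default.
        vertexai=IntegrationSettings(
            op_settings=OpSettings(postprocess_inputs=redact_inputs)
        ),
    )

    autopatch(settings)  # threads each IntegrationSettings into its get_*_patcher
    reset_autopatch()  # calls undo_patch() on the same cached patcher instances

One caveat worth noting: each get_*_patcher caches its MultiPatcher in a module-level global, so the settings passed on the first call win; later calls, including the argument-less ones in reset_autopatch, reuse the cached instance.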