From 3a3fb64b32a188d52a1ca5b4c17c375d0c6343d0 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 11 Dec 2024 17:50:27 -0500 Subject: [PATCH 1/5] test --- tests/integrations/litellm/litellm_test.py | 6 +- weave/integrations/anthropic/anthropic_sdk.py | 113 ++++++----- weave/integrations/cerebras/cerebras_sdk.py | 66 ++++--- weave/integrations/cohere/__init__.py | 2 +- weave/integrations/cohere/cohere_sdk.py | 178 +++++++++++------- .../google_ai_studio/google_ai_studio_sdk.py | 111 +++++++---- weave/integrations/groq/groq_sdk.py | 59 ++++-- 7 files changed, 337 insertions(+), 198 deletions(-) diff --git a/tests/integrations/litellm/litellm_test.py b/tests/integrations/litellm/litellm_test.py index 8cc0966c476..ffd1094e653 100644 --- a/tests/integrations/litellm/litellm_test.py +++ b/tests/integrations/litellm/litellm_test.py @@ -8,7 +8,7 @@ from packaging.version import parse as version_parse import weave -from weave.integrations.litellm.litellm import litellm_patcher +from weave.integrations.litellm.litellm import get_litellm_patcher # This PR: # https://github.com/BerriAI/litellm/commit/fe2aa706e8ff4edbcd109897e5da6b83ef6ad693 @@ -38,9 +38,9 @@ def patch_litellm(request: Any) -> Generator[None, None, None]: yield return - litellm_patcher.attempt_patch() + get_litellm_patcher().attempt_patch() yield - litellm_patcher.undo_patch() + get_litellm_patcher().undo_patch() @pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode diff --git a/weave/integrations/anthropic/anthropic_sdk.py b/weave/integrations/anthropic/anthropic_sdk.py index 6e33e1a3906..e583e8de1bc 100644 --- a/weave/integrations/anthropic/anthropic_sdk.py +++ b/weave/integrations/anthropic/anthropic_sdk.py @@ -1,15 +1,10 @@ import importlib from collections.abc import AsyncIterator, Iterator from functools import wraps -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Optional, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import _IteratorWrapper, add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher @@ -17,6 +12,8 @@ from anthropic.lib.streaming import MessageStream from anthropic.types import Message, MessageStreamEvent +_anthropic_patcher: Optional[MultiPatcher] = None + def anthropic_accumulator( acc: Optional["Message"], @@ -73,13 +70,11 @@ def should_use_accumulator(inputs: dict) -> bool: return isinstance(inputs, dict) and bool(inputs.get("stream")) -def create_wrapper_sync( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: "We need to do this so we can check if `stream` is used" - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: anthropic_accumulator, @@ -92,9 +87,7 @@ def wrapper(fn: Callable) -> Callable: # Surprisingly, the async `client.chat.completions.create` does not pass # `inspect.iscoroutinefunction`, so we can't dispatch on it and must write # it manually here... -def create_wrapper_async( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) @@ -104,8 +97,8 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper "We need to do this so we can check if `stream` is used" - op = weave.op()(_fn_wrapper(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_fn_wrapper(fn), **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: anthropic_accumulator, @@ -171,12 +164,10 @@ def text_stream(self) -> Union[Iterator[str], AsyncIterator[str]]: return self.__stream_text__() -def create_stream_wrapper( - name: str, -) -> Callable[[Callable], Callable]: +def create_stream_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda _: anthropic_stream_accumulator, @@ -187,28 +178,56 @@ def wrapper(fn: Callable) -> Callable: return wrapper -anthropic_patcher = MultiPatcher( - [ - # Patch the sync messages.create method for all messages.create methods - SymbolPatcher( - lambda: importlib.import_module("anthropic.resources.messages"), - "Messages.create", - create_wrapper_sync(name="anthropic.Messages.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("anthropic.resources.messages"), - "AsyncMessages.create", - create_wrapper_async(name="anthropic.AsyncMessages.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("anthropic.resources.messages"), - "Messages.stream", - create_stream_wrapper(name="anthropic.Messages.stream"), - ), - SymbolPatcher( - lambda: importlib.import_module("anthropic.resources.messages"), - "AsyncMessages.stream", - create_stream_wrapper(name="anthropic.AsyncMessages.stream"), - ), - ] -) +def get_anthropic_patcher( + settings: Optional[IntegrationSettings] = None, +) -> MultiPatcher: + global _anthropic_patcher + + if _anthropic_patcher is not None: + return _anthropic_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + + messages_create_settings = base.model_copy( + update={"name": base.name or "anthropic.Messages.create"} + ) + async_messages_create_settings = base.model_copy( + update={"name": base.name or "anthropic.AsyncMessages.create"} + ) + stream_settings = base.model_copy( + update={"name": base.name or "anthropic.Messages.stream"} + ) + async_stream_settings = base.model_copy( + update={"name": base.name or "anthropic.AsyncMessages.stream"} + ) + + _anthropic_patcher = MultiPatcher( + [ + # Patch the sync messages.create method for all messages.create methods + SymbolPatcher( + lambda: importlib.import_module("anthropic.resources.messages"), + "Messages.create", + create_wrapper_sync(messages_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("anthropic.resources.messages"), + "AsyncMessages.create", + create_wrapper_async(async_messages_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("anthropic.resources.messages"), + "Messages.stream", + create_stream_wrapper(stream_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("anthropic.resources.messages"), + "AsyncMessages.stream", + create_stream_wrapper(async_stream_settings), + ), + ] + ) + + return _anthropic_patcher diff --git a/weave/integrations/cerebras/cerebras_sdk.py b/weave/integrations/cerebras/cerebras_sdk.py index bdce368290e..15fa1cf7935 100644 --- a/weave/integrations/cerebras/cerebras_sdk.py +++ b/weave/integrations/cerebras/cerebras_sdk.py @@ -1,25 +1,24 @@ import importlib from functools import wraps -from typing import Any, Callable +from typing import Any, Callable, Optional import weave +from weave.trace.autopatch import OpSettings from weave.trace.patcher import MultiPatcher, SymbolPatcher +_cerebras_patcher: Optional[MultiPatcher] = None -def create_wrapper_sync( - name: str, -) -> Callable[[Callable], Callable]: + +def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return op return wrapper -def create_wrapper_async( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) @@ -28,24 +27,41 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper - op = weave.op()(_fn_wrapper(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_fn_wrapper(fn), **op_kwargs) return op return wrapper -cerebras_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("cerebras.cloud.sdk.resources.chat"), - "CompletionsResource.create", - create_wrapper_sync(name="cerebras.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("cerebras.cloud.sdk.resources.chat"), - "AsyncCompletionsResource.create", - create_wrapper_async(name="cerebras.chat.completions.create"), - ), - ] -) +def get_cerebras_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher: + global _cerebras_patcher + + if _cerebras_patcher is not None: + return _cerebras_patcher + + if settings is None: + settings = OpSettings() + + base = settings.op_settings + + create_settings = base.model_copy( + update={"name": base.name or "cerebras.chat.completions.create"} + ) + + _cerebras_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("cerebras.cloud.sdk.resources.chat"), + "CompletionsResource.create", + create_wrapper_sync(create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("cerebras.cloud.sdk.resources.chat"), + "AsyncCompletionsResource.create", + create_wrapper_async(create_settings), + ), + ] + ) + + return _cerebras_patcher diff --git a/weave/integrations/cohere/__init__.py b/weave/integrations/cohere/__init__.py index 45f925b6eea..288cce91aae 100644 --- a/weave/integrations/cohere/__init__.py +++ b/weave/integrations/cohere/__init__.py @@ -1 +1 @@ -from weave.integrations.cohere.cohere_sdk import cohere_patcher as cohere_patcher +from weave.integrations.cohere.cohere_sdk import get_cohere_patcher # noqa: F401 diff --git a/weave/integrations/cohere/cohere_sdk.py b/weave/integrations/cohere/cohere_sdk.py index b0a5944795b..aaaa7b10639 100644 --- a/weave/integrations/cohere/cohere_sdk.py +++ b/weave/integrations/cohere/cohere_sdk.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional import weave +from weave.trace.autopatch import OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher @@ -11,6 +12,9 @@ from cohere.v2.types.non_streamed_chat_response2 import NonStreamedChatResponse2 +_cohere_patcher: Optional[MultiPatcher] = None + + def cohere_accumulator( acc: Optional[dict], value: Any, @@ -86,16 +90,16 @@ def _accumulate_content( return acc -def cohere_wrapper(name: str) -> Callable: +def cohere_wrapper(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: - op = weave.op(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return op return wrapper -def cohere_wrapper_v2(name: str) -> Callable: +def cohere_wrapper_v2(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: def _post_process_response(fn: Callable) -> Any: @wraps(fn) @@ -122,14 +126,14 @@ def _wrapper(*args: Any, **kwargs: Any) -> Any: return _wrapper - op = weave.op(_post_process_response(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_post_process_response(fn), **op_kwargs) return op return wrapper -def cohere_wrapper_async_v2(name: str) -> Callable: +def cohere_wrapper_async_v2(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: def _post_process_response(fn: Callable) -> Any: @wraps(fn) @@ -156,81 +160,115 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any: return _wrapper - op = weave.op(_post_process_response(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_post_process_response(fn), **op_kwargs) return op return wrapper -def cohere_stream_wrapper(name: str) -> Callable: +def cohere_stream_wrapper(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: - op = weave.op(fn) - op.name = name # type: ignore - return add_accumulator(op, lambda inputs: cohere_accumulator) # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) + return add_accumulator(op, lambda inputs: cohere_accumulator) return wrapper -def cohere_stream_wrapper_v2(name: str) -> Callable: +def cohere_stream_wrapper_v2(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: - op = weave.op(fn) - op.name = name # type: ignore - return add_accumulator( - op, make_accumulator=lambda inputs: cohere_accumulator_v2 - ) + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) + return add_accumulator(op, lambda inputs: cohere_accumulator_v2) return wrapper -cohere_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "Client.chat", - cohere_wrapper("cohere.Client.chat"), - ), - # Patch the async chat method - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "AsyncClient.chat", - cohere_wrapper("cohere.AsyncClient.chat"), - ), - # Add patch for chat_stream method - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "Client.chat_stream", - cohere_stream_wrapper("cohere.Client.chat_stream"), - ), - # Add patch for async chat_stream method - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "AsyncClient.chat_stream", - cohere_stream_wrapper("cohere.AsyncClient.chat_stream"), - ), - # Add patch for cohere v2 - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "ClientV2.chat", - cohere_wrapper_v2("cohere.ClientV2.chat"), - ), - # Add patch for cohre v2 async chat method - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "AsyncClientV2.chat", - cohere_wrapper_async_v2("cohere.AsyncClientV2.chat"), - ), - # Add patch for chat_stream method v2 - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "ClientV2.chat_stream", - cohere_stream_wrapper_v2("cohere.ClientV2.chat_stream"), - ), - # Add patch for async chat_stream method v2 - SymbolPatcher( - lambda: importlib.import_module("cohere"), - "AsyncClientV2.chat_stream", - cohere_stream_wrapper_v2("cohere.AsyncClientV2.chat_stream"), - ), - ] -) +def get_cohere_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher: + global _cohere_patcher + + if _cohere_patcher is not None: + return _cohere_patcher + + if settings is None: + settings = OpSettings() + + base = settings.op_settings + + chat_settings = base.model_copy(update={"name": base.name or "cohere.Client.chat"}) + async_chat_settings = base.model_copy( + update={"name": base.name or "cohere.AsyncClient.chat"} + ) + chat_stream_settings = base.model_copy( + update={"name": base.name or "cohere.Client.chat_stream"} + ) + async_chat_stream_settings = base.model_copy( + update={"name": base.name or "cohere.AsyncClient.chat_stream"} + ) + chat_v2_settings = base.model_copy( + update={"name": base.name or "cohere.ClientV2.chat"} + ) + async_chat_v2_settings = base.model_copy( + update={"name": base.name or "cohere.AsyncClientV2.chat"} + ) + chat_stream_v2_settings = base.model_copy( + update={"name": base.name or "cohere.ClientV2.chat_stream"} + ) + async_chat_stream_v2_settings = base.model_copy( + update={"name": base.name or "cohere.AsyncClientV2.chat_stream"} + ) + + _cohere_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "Client.chat", + cohere_wrapper(chat_settings), + ), + # Patch the async chat method + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "AsyncClient.chat", + cohere_wrapper(async_chat_settings), + ), + # Add patch for chat_stream method + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "Client.chat_stream", + cohere_stream_wrapper(chat_stream_settings), + ), + # Add patch for async chat_stream method + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "AsyncClient.chat_stream", + cohere_stream_wrapper(async_chat_stream_settings), + ), + # Add patch for cohere v2 + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "ClientV2.chat", + cohere_wrapper_v2(chat_v2_settings), + ), + # Add patch for cohre v2 async chat method + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "AsyncClientV2.chat", + cohere_wrapper_async_v2(async_chat_v2_settings), + ), + # Add patch for chat_stream method v2 + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "ClientV2.chat_stream", + cohere_stream_wrapper_v2(chat_stream_v2_settings), + ), + # Add patch for async chat_stream method v2 + SymbolPatcher( + lambda: importlib.import_module("cohere"), + "AsyncClientV2.chat_stream", + cohere_stream_wrapper_v2(async_chat_stream_v2_settings), + ), + ] + ) + + return _cohere_patcher diff --git a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py index 2cd8a2fe137..68c2a0b39ea 100644 --- a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py +++ b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional import weave +from weave.trace.autopatch import OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher from weave.trace.serialize import dictify @@ -11,6 +12,8 @@ if TYPE_CHECKING: from google.generativeai.types.generation_types import GenerateContentResponse +_google_genai_patcher: Optional[MultiPatcher] = None + def gemini_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: if "self" in inputs: @@ -89,10 +92,13 @@ def gemini_on_finish( call.summary.update(summary_update) -def gemini_wrapper_sync(name: str) -> Callable[[Callable], Callable]: +def gemini_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op(postprocess_inputs=gemini_postprocess_inputs)(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + if not op_kwargs.get("postprocess_inputs"): + op_kwargs["postprocess_inputs"] = gemini_postprocess_inputs + + op = weave.op(fn, **op_kwargs) op._set_on_finish_handler(gemini_on_finish) return add_accumulator( op, # type: ignore @@ -104,7 +110,7 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def gemini_wrapper_async(name: str) -> Callable[[Callable], Callable]: +def gemini_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) @@ -113,9 +119,11 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper - "We need to do this so we can check if `stream` is used" - op = weave.op(postprocess_inputs=gemini_postprocess_inputs)(_fn_wrapper(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + if not op_kwargs.get("postprocess_inputs"): + op_kwargs["postprocess_inputs"] = gemini_postprocess_inputs + + op = weave.op(_fn_wrapper(fn), **op_kwargs) op._set_on_finish_handler(gemini_on_finish) return add_accumulator( op, # type: ignore @@ -127,33 +135,68 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -google_genai_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("google.generativeai.generative_models"), - "GenerativeModel.generate_content", - gemini_wrapper_sync( - name="google.generativeai.GenerativeModel.generate_content" +def get_google_genai_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher: + global _google_genai_patcher + + if _google_genai_patcher is not None: + return _google_genai_patcher + + if settings is None: + settings = OpSettings() + + base = settings.op_settings + + generate_content_settings = base.model_copy( + update={ + "name": base.name or "google.generativeai.GenerativeModel.generate_content" + } + ) + generate_content_async_settings = base.model_copy( + update={ + "name": base.name + or "google.generativeai.GenerativeModel.generate_content_async" + } + ) + send_message_settings = base.model_copy( + update={"name": base.name or "google.generativeai.ChatSession.send_message"} + ) + send_message_async_settings = base.model_copy( + update={ + "name": base.name or "google.generativeai.ChatSession.send_message_async" + } + ) + + _google_genai_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module( + "google.generativeai.generative_models" + ), + "GenerativeModel.generate_content", + gemini_wrapper_sync(generate_content_settings), ), - ), - SymbolPatcher( - lambda: importlib.import_module("google.generativeai.generative_models"), - "GenerativeModel.generate_content_async", - gemini_wrapper_async( - name="google.generativeai.GenerativeModel.generate_content_async" + SymbolPatcher( + lambda: importlib.import_module( + "google.generativeai.generative_models" + ), + "GenerativeModel.generate_content_async", + gemini_wrapper_async(generate_content_async_settings), ), - ), - SymbolPatcher( - lambda: importlib.import_module("google.generativeai.generative_models"), - "ChatSession.send_message", - gemini_wrapper_sync(name="google.generativeai.ChatSession.send_message"), - ), - SymbolPatcher( - lambda: importlib.import_module("google.generativeai.generative_models"), - "ChatSession.send_message_async", - gemini_wrapper_async( - name="google.generativeai.ChatSession.send_message_async" + SymbolPatcher( + lambda: importlib.import_module( + "google.generativeai.generative_models" + ), + "ChatSession.send_message", + gemini_wrapper_sync(send_message_settings), ), - ), - ] -) + SymbolPatcher( + lambda: importlib.import_module( + "google.generativeai.generative_models" + ), + "ChatSession.send_message_async", + gemini_wrapper_async(send_message_async_settings), + ), + ] + ) + + return _google_genai_patcher diff --git a/weave/integrations/groq/groq_sdk.py b/weave/integrations/groq/groq_sdk.py index 4f470e6d743..2de5462da90 100644 --- a/weave/integrations/groq/groq_sdk.py +++ b/weave/integrations/groq/groq_sdk.py @@ -1,6 +1,8 @@ import importlib from typing import TYPE_CHECKING, Callable, Optional +from weave.trace.autopatch import OpSettings + if TYPE_CHECKING: from groq.types.chat import ChatCompletion, ChatCompletionChunk @@ -8,6 +10,8 @@ from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher +_groq_patcher: Optional[MultiPatcher] = None + def groq_accumulator( acc: Optional["ChatCompletion"], value: "ChatCompletionChunk" @@ -83,11 +87,10 @@ def should_use_accumulator(inputs: dict) -> bool: return isinstance(inputs, dict) and bool(inputs.get("stream")) -def groq_wrapper(name: str) -> Callable[[Callable], Callable]: +def groq_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore - # return op + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: groq_accumulator, @@ -97,17 +100,37 @@ def wrapper(fn: Callable) -> Callable: return wrapper -groq_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("groq.resources.chat.completions"), - "Completions.create", - groq_wrapper(name="groq.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("groq.resources.chat.completions"), - "AsyncCompletions.create", - groq_wrapper(name="groq.async.chat.completions.create"), - ), - ] -) +def get_groq_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher: + global _groq_patcher + + if _groq_patcher is not None: + return _groq_patcher + + if settings is None: + settings = OpSettings() + + base = settings.op_settings + + chat_completions_settings = base.model_copy( + update={"name": base.name or "groq.chat.completions.create"} + ) + async_chat_completions_settings = base.model_copy( + update={"name": base.name or "groq.async.chat.completions.create"} + ) + + _groq_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("groq.resources.chat.completions"), + "Completions.create", + groq_wrapper(chat_completions_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("groq.resources.chat.completions"), + "AsyncCompletions.create", + groq_wrapper(async_chat_completions_settings), + ), + ] + ) + + return _groq_patcher From da6fb3958aa8a1945dd50beaa72b5839d964f691 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 11 Dec 2024 18:09:40 -0500 Subject: [PATCH 2/5] test --- weave/integrations/cerebras/cerebras_sdk.py | 8 +- weave/integrations/cohere/cohere_sdk.py | 6 +- .../google_ai_studio/google_ai_studio_sdk.py | 8 +- weave/integrations/groq/groq_sdk.py | 6 +- .../instructor/instructor_iterable_utils.py | 14 +- .../instructor/instructor_partial_utils.py | 7 +- .../integrations/instructor/instructor_sdk.py | 79 +++++++---- weave/integrations/litellm/litellm.py | 57 +++++--- weave/integrations/mistral/__init__.py | 4 +- weave/integrations/mistral/v0/mistral.py | 106 +++++++++----- weave/integrations/mistral/v1/mistral.py | 96 ++++++++----- weave/integrations/notdiamond/__init__.py | 2 +- weave/integrations/notdiamond/tracing.py | 131 +++++++++++++----- weave/integrations/vertexai/vertexai_sdk.py | 112 ++++++++++----- weave/trace/autopatch.py | 121 ++++++++-------- 15 files changed, 490 insertions(+), 267 deletions(-) diff --git a/weave/integrations/cerebras/cerebras_sdk.py b/weave/integrations/cerebras/cerebras_sdk.py index 15fa1cf7935..1b6fee49006 100644 --- a/weave/integrations/cerebras/cerebras_sdk.py +++ b/weave/integrations/cerebras/cerebras_sdk.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Optional import weave -from weave.trace.autopatch import OpSettings +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.patcher import MultiPatcher, SymbolPatcher _cerebras_patcher: Optional[MultiPatcher] = None @@ -34,14 +34,16 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -def get_cerebras_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher: +def get_cerebras_patcher( + settings: Optional[IntegrationSettings] = None, +) -> MultiPatcher: global _cerebras_patcher if _cerebras_patcher is not None: return _cerebras_patcher if settings is None: - settings = OpSettings() + settings = IntegrationSettings() base = settings.op_settings diff --git a/weave/integrations/cohere/cohere_sdk.py b/weave/integrations/cohere/cohere_sdk.py index aaaa7b10639..e554e659478 100644 --- a/weave/integrations/cohere/cohere_sdk.py +++ b/weave/integrations/cohere/cohere_sdk.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional import weave -from weave.trace.autopatch import OpSettings +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher @@ -185,14 +185,14 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def get_cohere_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher: +def get_cohere_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: global _cohere_patcher if _cohere_patcher is not None: return _cohere_patcher if settings is None: - settings = OpSettings() + settings = IntegrationSettings() base = settings.op_settings diff --git a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py index 68c2a0b39ea..d86fe5b11f3 100644 --- a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py +++ b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional import weave -from weave.trace.autopatch import OpSettings +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher from weave.trace.serialize import dictify @@ -135,14 +135,16 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -def get_google_genai_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher: +def get_google_genai_patcher( + settings: Optional[IntegrationSettings] = None, +) -> MultiPatcher: global _google_genai_patcher if _google_genai_patcher is not None: return _google_genai_patcher if settings is None: - settings = OpSettings() + settings = IntegrationSettings() base = settings.op_settings diff --git a/weave/integrations/groq/groq_sdk.py b/weave/integrations/groq/groq_sdk.py index 2de5462da90..fe8db33bcf5 100644 --- a/weave/integrations/groq/groq_sdk.py +++ b/weave/integrations/groq/groq_sdk.py @@ -1,7 +1,7 @@ import importlib from typing import TYPE_CHECKING, Callable, Optional -from weave.trace.autopatch import OpSettings +from weave.trace.autopatch import IntegrationSettings, OpSettings if TYPE_CHECKING: from groq.types.chat import ChatCompletion, ChatCompletionChunk @@ -100,14 +100,14 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def get_groq_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher: +def get_groq_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: global _groq_patcher if _groq_patcher is not None: return _groq_patcher if settings is None: - settings = OpSettings() + settings = IntegrationSettings() base = settings.op_settings diff --git a/weave/integrations/instructor/instructor_iterable_utils.py b/weave/integrations/instructor/instructor_iterable_utils.py index 84d64a103b6..3b0f128a132 100644 --- a/weave/integrations/instructor/instructor_iterable_utils.py +++ b/weave/integrations/instructor/instructor_iterable_utils.py @@ -4,6 +4,7 @@ from pydantic import BaseModel import weave +from weave.trace.autopatch import OpSettings from weave.trace.op_extensions.accumulator import add_accumulator @@ -27,10 +28,10 @@ def should_accumulate_iterable(inputs: dict) -> bool: return False -def instructor_wrapper_sync(name: str) -> Callable[[Callable], Callable]: +def instructor_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: instructor_iterable_accumulator, @@ -40,7 +41,7 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def instructor_wrapper_async(name: str) -> Callable[[Callable], Callable]: +def instructor_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) @@ -49,9 +50,8 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper - "We need to do this so we can check if `stream` is used" - op = weave.op(_fn_wrapper(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_fn_wrapper(fn), **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: instructor_iterable_accumulator, diff --git a/weave/integrations/instructor/instructor_partial_utils.py b/weave/integrations/instructor/instructor_partial_utils.py index f90dc7edb17..8efa84b302f 100644 --- a/weave/integrations/instructor/instructor_partial_utils.py +++ b/weave/integrations/instructor/instructor_partial_utils.py @@ -3,6 +3,7 @@ from pydantic import BaseModel import weave +from weave.trace.autopatch import OpSettings from weave.trace.op_extensions.accumulator import add_accumulator @@ -14,10 +15,10 @@ def instructor_partial_accumulator( return acc -def instructor_wrapper_partial(name: str) -> Callable[[Callable], Callable]: +def instructor_wrapper_partial(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: instructor_partial_accumulator, diff --git a/weave/integrations/instructor/instructor_sdk.py b/weave/integrations/instructor/instructor_sdk.py index 867e9f2a785..2716d570cbb 100644 --- a/weave/integrations/instructor/instructor_sdk.py +++ b/weave/integrations/instructor/instructor_sdk.py @@ -1,31 +1,62 @@ import importlib +from typing import Optional +from weave.trace.autopatch import IntegrationSettings from weave.trace.patcher import MultiPatcher, SymbolPatcher from .instructor_iterable_utils import instructor_wrapper_async, instructor_wrapper_sync from .instructor_partial_utils import instructor_wrapper_partial -instructor_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("instructor.client"), - "Instructor.create", - instructor_wrapper_sync(name="Instructor.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("instructor.client"), - "AsyncInstructor.create", - instructor_wrapper_async(name="AsyncInstructor.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("instructor.client"), - "Instructor.create_partial", - instructor_wrapper_partial(name="Instructor.create_partial"), - ), - SymbolPatcher( - lambda: importlib.import_module("instructor.client"), - "AsyncInstructor.create_partial", - instructor_wrapper_partial(name="AsyncInstructor.create_partial"), - ), - ] -) +_instructor_patcher: Optional[MultiPatcher] = None + + +def get_instructor_patcher( + settings: Optional[IntegrationSettings] = None, +) -> MultiPatcher: + global _instructor_patcher + + if _instructor_patcher is not None: + return _instructor_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + + create_settings = base.model_copy(update={"name": base.name or "Instructor.create"}) + async_create_settings = base.model_copy( + update={"name": base.name or "AsyncInstructor.create"} + ) + create_partial_settings = base.model_copy( + update={"name": base.name or "Instructor.create_partial"} + ) + async_create_partial_settings = base.model_copy( + update={"name": base.name or "AsyncInstructor.create_partial"} + ) + + _instructor_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("instructor.client"), + "Instructor.create", + instructor_wrapper_sync(create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("instructor.client"), + "AsyncInstructor.create", + instructor_wrapper_async(async_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("instructor.client"), + "Instructor.create_partial", + instructor_wrapper_partial(create_partial_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("instructor.client"), + "AsyncInstructor.create_partial", + instructor_wrapper_partial(async_create_partial_settings), + ), + ] + ) + + return _instructor_patcher diff --git a/weave/integrations/litellm/litellm.py b/weave/integrations/litellm/litellm.py index c3bf1bf114a..8a1818820a0 100644 --- a/weave/integrations/litellm/litellm.py +++ b/weave/integrations/litellm/litellm.py @@ -2,12 +2,15 @@ from typing import TYPE_CHECKING, Any, Callable, Optional import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher if TYPE_CHECKING: from litellm.utils import ModelResponse +_litellm_patcher: Optional[MultiPatcher] = None + # This accumulator is nearly identical to the mistral accumulator, just with different types. def litellm_accumulator( @@ -82,10 +85,10 @@ def should_use_accumulator(inputs: dict) -> bool: return isinstance(inputs, dict) and bool(inputs.get("stream")) -def make_wrapper(name: str) -> Callable: +def make_wrapper(settings: OpSettings) -> Callable: def litellm_wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return add_accumulator( op, # type: ignore make_accumulator=lambda inputs: litellm_accumulator, @@ -96,17 +99,37 @@ def litellm_wrapper(fn: Callable) -> Callable: return litellm_wrapper -litellm_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("litellm"), - "completion", - make_wrapper("litellm.completion"), - ), - SymbolPatcher( - lambda: importlib.import_module("litellm"), - "acompletion", - make_wrapper("litellm.acompletion"), - ), - ] -) +def get_litellm_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: + global _litellm_patcher + + if _litellm_patcher is not None: + return _litellm_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + + completion_settings = base.model_copy( + update={"name": base.name or "litellm.completion"} + ) + acompletion_settings = base.model_copy( + update={"name": base.name or "litellm.acompletion"} + ) + + _litellm_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("litellm"), + "completion", + make_wrapper(completion_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("litellm"), + "acompletion", + make_wrapper(acompletion_settings), + ), + ] + ) + + return _litellm_patcher diff --git a/weave/integrations/mistral/__init__.py b/weave/integrations/mistral/__init__.py index 34b40835efc..d78812d9afa 100644 --- a/weave/integrations/mistral/__init__.py +++ b/weave/integrations/mistral/__init__.py @@ -8,10 +8,10 @@ mistral_version = "1.0" # we need to return a patching function if version.parse(mistral_version) < version.parse("1.0.0"): - from .v0.mistral import mistral_patcher + from .v0.mistral import get_mistral_patcher # noqa: F401 print( f"Using MistralAI version {mistral_version}. Please consider upgrading to version 1.0.0 or later." ) else: - from .v1.mistral import mistral_patcher # noqa: F401 + from .v1.mistral import get_mistral_patcher # noqa: F401 diff --git a/weave/integrations/mistral/v0/mistral.py b/weave/integrations/mistral/v0/mistral.py index 6d4915eab41..70c81f449e2 100644 --- a/weave/integrations/mistral/v0/mistral.py +++ b/weave/integrations/mistral/v0/mistral.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Callable, Optional import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher @@ -11,6 +12,8 @@ ChatCompletionStreamResponse, ) +_mistral_patcher: Optional[MultiPatcher] = None + def mistral_accumulator( acc: Optional["ChatCompletionResponse"], @@ -72,37 +75,72 @@ def mistral_accumulator( return acc -def mistral_stream_wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - acc_op = add_accumulator(op, lambda inputs: mistral_accumulator) # type: ignore - return acc_op - - -mistral_patcher = MultiPatcher( - [ - # Patch the sync, non-streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.client"), - "MistralClient.chat", - weave.op(), - ), - # Patch the sync, streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.client"), - "MistralClient.chat_stream", - mistral_stream_wrapper, - ), - # Patch the async, non-streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.async_client"), - "MistralAsyncClient.chat", - weave.op(), - ), - # Patch the async, streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.async_client"), - "MistralAsyncClient.chat_stream", - mistral_stream_wrapper, - ), - ] -) +def mistral_stream_wrapper(settings: OpSettings) -> Callable: + def wrapper(fn: Callable) -> Callable: + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) + acc_op = add_accumulator(op, lambda inputs: mistral_accumulator) # type: ignore + return acc_op + + return wrapper + + +def mistral_wrapper(settings: OpSettings) -> Callable: + def wrapper(fn: Callable) -> Callable: + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) + return op + + return wrapper + + +def get_mistral_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: + global _mistral_patcher + + if _mistral_patcher is not None: + return _mistral_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + + chat_settings = base.model_copy(update={"name": base.name or "mistralai.chat"}) + chat_stream_settings = base.model_copy( + update={"name": base.name or "mistralai.chat_stream"} + ) + async_chat_settings = base.model_copy( + update={"name": base.name or "mistralai.async_client.chat"} + ) + async_chat_stream_settings = base.model_copy( + update={"name": base.name or "mistralai.async_client.chat_stream"} + ) + + mistral_patcher = MultiPatcher( + [ + # Patch the sync, non-streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.client"), + "MistralClient.chat", + mistral_wrapper(chat_settings), + ), + # Patch the sync, streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.client"), + "MistralClient.chat_stream", + mistral_stream_wrapper(chat_stream_settings), + ), + # Patch the async, non-streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.async_client"), + "MistralAsyncClient.chat", + mistral_wrapper(async_chat_settings), + ), + # Patch the async, streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.async_client"), + "MistralAsyncClient.chat_stream", + mistral_stream_wrapper(async_chat_stream_settings), + ), + ] + ) diff --git a/weave/integrations/mistral/v1/mistral.py b/weave/integrations/mistral/v1/mistral.py index 692aa7b159b..a7824803c40 100644 --- a/weave/integrations/mistral/v1/mistral.py +++ b/weave/integrations/mistral/v1/mistral.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Callable, Optional import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher @@ -11,6 +12,8 @@ CompletionEvent, ) +_mistral_patcher: Optional[MultiPatcher] = None + def mistral_accumulator( acc: Optional["ChatCompletionResponse"], @@ -79,50 +82,75 @@ def mistral_accumulator( return acc -def mistral_stream_wrapper(name: str) -> Callable: +def mistral_stream_wrapper(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) acc_op = add_accumulator(op, lambda inputs: mistral_accumulator) # type: ignore - acc_op.name = name # type: ignore return acc_op return wrapper -def mistral_wrapper(name: str) -> Callable: +def mistral_wrapper(settings: OpSettings) -> Callable: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return op return wrapper -mistral_patcher = MultiPatcher( - [ - # Patch the sync, non-streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.chat"), - "Chat.complete", - mistral_wrapper(name="Mistral.chat.complete"), - ), - # Patch the sync, streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.chat"), - "Chat.stream", - mistral_stream_wrapper(name="Mistral.chat.stream"), - ), - # Patch the async, non-streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.chat"), - "Chat.complete_async", - mistral_wrapper(name="Mistral.chat.complete_async"), - ), - # Patch the async, streaming chat method - SymbolPatcher( - lambda: importlib.import_module("mistralai.chat"), - "Chat.stream_async", - mistral_stream_wrapper(name="Mistral.chat.stream_async"), - ), - ] -) +def get_mistral_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: + global _mistral_patcher + + if _mistral_patcher is not None: + return _mistral_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + chat_complete_settings = base.model_copy( + update={"name": base.name or "mistralai.chat.complete"} + ) + chat_stream_settings = base.model_copy( + update={"name": base.name or "mistralai.chat.stream"} + ) + async_chat_complete_settings = base.model_copy( + update={"name": base.name or "mistralai.async_client.chat.complete"} + ) + async_chat_stream_settings = base.model_copy( + update={"name": base.name or "mistralai.async_client.chat.stream"} + ) + + _mistral_patcher = MultiPatcher( + [ + # Patch the sync, non-streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.chat"), + "Chat.complete", + mistral_wrapper(chat_complete_settings), + ), + # Patch the sync, streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.chat"), + "Chat.stream", + mistral_stream_wrapper(chat_stream_settings), + ), + # Patch the async, non-streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.chat"), + "Chat.complete_async", + mistral_wrapper(async_chat_complete_settings), + ), + # Patch the async, streaming chat method + SymbolPatcher( + lambda: importlib.import_module("mistralai.chat"), + "Chat.stream_async", + mistral_stream_wrapper(async_chat_stream_settings), + ), + ] + ) + + return _mistral_patcher diff --git a/weave/integrations/notdiamond/__init__.py b/weave/integrations/notdiamond/__init__.py index d99c31c4176..8cb72ef2a55 100644 --- a/weave/integrations/notdiamond/__init__.py +++ b/weave/integrations/notdiamond/__init__.py @@ -1 +1 @@ -from .tracing import notdiamond_patcher as notdiamond_patcher +from .tracing import get_notdiamond_patcher # noqa: F401 diff --git a/weave/integrations/notdiamond/tracing.py b/weave/integrations/notdiamond/tracing.py index 90c08b9e8c6..ae1e5217725 100644 --- a/weave/integrations/notdiamond/tracing.py +++ b/weave/integrations/notdiamond/tracing.py @@ -1,19 +1,30 @@ import importlib -from typing import Callable +from typing import Callable, Optional import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.patcher import MultiPatcher, SymbolPatcher +_notdiamond_patcher: Optional[MultiPatcher] = None -def nd_wrapper(name: str) -> Callable[[Callable], Callable]: + +def nd_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return op return wrapper +def passthrough_wrapper(settings: OpSettings) -> Callable: + def wrapper(fn: Callable) -> Callable: + op_kwargs = settings.model_dump() + return weave.op(fn, **op_kwargs) + + return wrapper + + def _patch_client_op(method_name: str) -> list[SymbolPatcher]: return [ SymbolPatcher( @@ -29,36 +40,82 @@ def _patch_client_op(method_name: str) -> list[SymbolPatcher]: ] -patched_client_functions = _patch_client_op("model_select") - -patched_llmconfig_functions = [ - SymbolPatcher( - lambda: importlib.import_module("notdiamond"), - "LLMConfig.__init__", - weave.op(), - ), - SymbolPatcher( - lambda: importlib.import_module("notdiamond"), - "LLMConfig.from_string", - weave.op(), - ), -] - -patched_toolkit_functions = [ - SymbolPatcher( - lambda: importlib.import_module("notdiamond.toolkit.custom_router"), - "CustomRouter.fit", - weave.op(), - ), - SymbolPatcher( - lambda: importlib.import_module("notdiamond.toolkit.custom_router"), - "CustomRouter.eval", - weave.op(), - ), -] - -all_patched_functions = ( - patched_client_functions + patched_toolkit_functions + patched_llmconfig_functions -) - -notdiamond_patcher = MultiPatcher(all_patched_functions) +def get_notdiamond_patcher( + settings: Optional[IntegrationSettings] = None, +) -> MultiPatcher: + global _notdiamond_patcher + + if _notdiamond_patcher is not None: + return _notdiamond_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + + model_select_settings = base.model_copy( + update={"name": base.name or "NotDiamond.model_select"} + ) + async_model_select_settings = base.model_copy( + update={"name": base.name or "NotDiamond.amodel_select"} + ) + patched_client_functions = [ + SymbolPatcher( + lambda: importlib.import_module("notdiamond"), + "NotDiamond.model_select", + passthrough_wrapper(model_select_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("notdiamond"), + "NotDiamond.amodel_select", + passthrough_wrapper(async_model_select_settings), + ), + ] + + llm_config_init_settings = base.model_copy( + update={"name": base.name or "NotDiamond.LLMConfig.__init__"} + ) + llm_config_from_string_settings = base.model_copy( + update={"name": base.name or "NotDiamond.LLMConfig.from_string"} + ) + patched_llmconfig_functions = [ + SymbolPatcher( + lambda: importlib.import_module("notdiamond"), + "LLMConfig.__init__", + passthrough_wrapper(llm_config_init_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("notdiamond"), + "LLMConfig.from_string", + passthrough_wrapper(llm_config_from_string_settings), + ), + ] + + toolkit_custom_router_fit_settings = base.model_copy( + update={"name": base.name or "NotDiamond.toolkit.custom_router.fit"} + ) + toolkit_custom_router_eval_settings = base.model_copy( + update={"name": base.name or "NotDiamond.toolkit.custom_router.eval"} + ) + patched_toolkit_functions = [ + SymbolPatcher( + lambda: importlib.import_module("notdiamond.toolkit.custom_router"), + "CustomRouter.fit", + passthrough_wrapper(toolkit_custom_router_fit_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("notdiamond.toolkit.custom_router"), + "CustomRouter.eval", + passthrough_wrapper(toolkit_custom_router_eval_settings), + ), + ] + + all_patched_functions = ( + patched_client_functions + + patched_toolkit_functions + + patched_llmconfig_functions + ) + + _notdiamond_patcher = MultiPatcher(all_patched_functions) + + return _notdiamond_patcher diff --git a/weave/integrations/vertexai/vertexai_sdk.py b/weave/integrations/vertexai/vertexai_sdk.py index c2f7a9906c7..f016135d775 100644 --- a/weave/integrations/vertexai/vertexai_sdk.py +++ b/weave/integrations/vertexai/vertexai_sdk.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher from weave.trace.serialize import dictify @@ -12,6 +13,9 @@ from vertexai.generative_models import GenerationResponse +_vertexai_patcher: Optional[MultiPatcher] = None + + def vertexai_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: if "self" in inputs: model_name = ( @@ -81,10 +85,13 @@ def vertexai_on_finish( call.summary.update(summary_update) -def vertexai_wrapper_sync(name: str) -> Callable[[Callable], Callable]: +def vertexai_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op(postprocess_inputs=vertexai_postprocess_inputs)(fn) - op.name = name # type: ignore + op_kwargs = settings.model_copy() + if not op_kwargs.get("postprocess_inputs"): + op_kwargs["postprocess_inputs"] = vertexai_postprocess_inputs + + op = weave.op(fn, **op_kwargs) op._set_on_finish_handler(vertexai_on_finish) return add_accumulator( op, # type: ignore @@ -96,7 +103,7 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def vertexai_wrapper_async(name: str) -> Callable[[Callable], Callable]: +def vertexai_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) @@ -105,9 +112,11 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper - "We need to do this so we can check if `stream` is used" - op = weave.op(postprocess_inputs=vertexai_postprocess_inputs)(_fn_wrapper(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_copy() + if not op_kwargs.get("postprocess_inputs"): + op_kwargs["postprocess_inputs"] = vertexai_postprocess_inputs + + op = weave.op(_fn_wrapper(fn), **op_kwargs) op._set_on_finish_handler(vertexai_on_finish) return add_accumulator( op, # type: ignore @@ -119,34 +128,63 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -vertexai_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "GenerativeModel.generate_content", - vertexai_wrapper_sync(name="vertexai.GenerativeModel.generate_content"), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "GenerativeModel.generate_content_async", - vertexai_wrapper_async( - name="vertexai.GenerativeModel.generate_content_async" +def get_vertexai_patcher( + settings: Optional[IntegrationSettings] = None, +) -> MultiPatcher: + global _vertexai_patcher + + if _vertexai_patcher is not None: + return _vertexai_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + + generate_content_settings = base.model_copy( + update={"name": base.name or "vertexai.GenerativeModel.generate_content"} + ) + generate_content_async_settings = base.model_copy( + update={"name": base.name or "vertexai.GenerativeModel.generate_content_async"} + ) + send_message_settings = base.model_copy( + update={"name": base.name or "vertexai.ChatSession.send_message"} + ) + send_message_async_settings = base.model_copy( + update={"name": base.name or "vertexai.ChatSession.send_message_async"} + ) + generate_images_settings = base.model_copy( + update={"name": base.name or "vertexai.ImageGenerationModel.generate_images"} + ) + + _vertexai_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "GenerativeModel.generate_content", + vertexai_wrapper_sync(generate_content_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "GenerativeModel.generate_content_async", + vertexai_wrapper_async(generate_content_async_settings), ), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "ChatSession.send_message", - vertexai_wrapper_sync(name="vertexai.ChatSession.send_message"), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "ChatSession.send_message_async", - vertexai_wrapper_async(name="vertexai.ChatSession.send_message_async"), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.preview.vision_models"), - "ImageGenerationModel.generate_images", - vertexai_wrapper_sync(name="vertexai.ImageGenerationModel.generate_images"), - ), - ] -) + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "ChatSession.send_message", + vertexai_wrapper_sync(send_message_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "ChatSession.send_message_async", + vertexai_wrapper_async(send_message_async_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("vertexai.preview.vision_models"), + "ImageGenerationModel.generate_images", + vertexai_wrapper_sync(generate_images_settings), + ), + ] + ) + + return _vertexai_patcher diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py index 0619194a224..c79cf3b315d 100644 --- a/weave/trace/autopatch.py +++ b/weave/trace/autopatch.py @@ -34,89 +34,92 @@ class AutopatchSettings(BaseModel): # These will be uncommented as we add support for more integrations. Note that - # anthropic: IntegrationSettings = Field(default_factory=IntegrationSettings) - # cerebras: IntegrationSettings = Field(default_factory=IntegrationSettings) - # cohere: IntegrationSettings = Field(default_factory=IntegrationSettings) - # dspy: IntegrationSettings = Field(default_factory=IntegrationSettings) - # google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings) - # groq: IntegrationSettings = Field(default_factory=IntegrationSettings) - # instructor: IntegrationSettings = Field(default_factory=IntegrationSettings) - # langchain: IntegrationSettings = Field(default_factory=IntegrationSettings) - # litellm: IntegrationSettings = Field(default_factory=IntegrationSettings) - # llamaindex: IntegrationSettings = Field(default_factory=IntegrationSettings) - # mistral: IntegrationSettings = Field(default_factory=IntegrationSettings) - # notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings) + anthropic: IntegrationSettings = Field(default_factory=IntegrationSettings) + cerebras: IntegrationSettings = Field(default_factory=IntegrationSettings) + cohere: IntegrationSettings = Field(default_factory=IntegrationSettings) + dspy: IntegrationSettings = Field(default_factory=IntegrationSettings) + google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings) + groq: IntegrationSettings = Field(default_factory=IntegrationSettings) + instructor: IntegrationSettings = Field(default_factory=IntegrationSettings) + langchain: IntegrationSettings = Field(default_factory=IntegrationSettings) + litellm: IntegrationSettings = Field(default_factory=IntegrationSettings) + llamaindex: IntegrationSettings = Field(default_factory=IntegrationSettings) + mistral: IntegrationSettings = Field(default_factory=IntegrationSettings) + notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings) openai: IntegrationSettings = Field(default_factory=IntegrationSettings) - # vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings) + vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings) @validate_call def autopatch(settings: Optional[AutopatchSettings] = None) -> None: - from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher - from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher - from weave.integrations.cohere.cohere_sdk import cohere_patcher - from weave.integrations.dspy.dspy_sdk import dspy_patcher + from weave.integrations.anthropic.anthropic_sdk import get_anthropic_patcher + from weave.integrations.cerebras.cerebras_sdk import get_cerebras_patcher + from weave.integrations.cohere.cohere_sdk import get_cohere_patcher + from weave.integrations.dspy.dspy_sdk import get_dspy_patcher from weave.integrations.google_ai_studio.google_ai_studio_sdk import ( - google_genai_patcher, + get_google_genai_patcher, ) - from weave.integrations.groq.groq_sdk import groq_patcher - from weave.integrations.instructor.instructor_sdk import instructor_patcher + from weave.integrations.groq.groq_sdk import get_groq_patcher + from weave.integrations.instructor.instructor_sdk import get_instructor_patcher from weave.integrations.langchain.langchain import langchain_patcher - from weave.integrations.litellm.litellm import litellm_patcher + from weave.integrations.litellm.litellm import get_litellm_patcher from weave.integrations.llamaindex.llamaindex import llamaindex_patcher - from weave.integrations.mistral import mistral_patcher - from weave.integrations.notdiamond.tracing import notdiamond_patcher + from weave.integrations.mistral import get_mistral_patcher + from weave.integrations.notdiamond.tracing import get_notdiamond_patcher from weave.integrations.openai.openai_sdk import get_openai_patcher - from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher + from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher if settings is None: settings = AutopatchSettings() get_openai_patcher(settings.openai).attempt_patch() - mistral_patcher.attempt_patch() - litellm_patcher.attempt_patch() - llamaindex_patcher.attempt_patch() + get_mistral_patcher(settings.mistral).attempt_patch() + get_litellm_patcher(settings.litellm).attempt_patch() + get_anthropic_patcher(settings.anthropic).attempt_patch() + get_groq_patcher(settings.groq).attempt_patch() + get_instructor_patcher(settings.instructor).attempt_patch() + get_dspy_patcher(settings.dspy).attempt_patch() + get_cerebras_patcher(settings.cerebras).attempt_patch() + get_cohere_patcher(settings.cohere).attempt_patch() + get_google_genai_patcher(settings.google_ai_studio).attempt_patch() + get_notdiamond_patcher(settings.notdiamond).attempt_patch() + get_vertexai_patcher(settings.vertexai).attempt_patch() + + # These integrations don't use the op decorator, so there are no patching settings for them langchain_patcher.attempt_patch() - anthropic_patcher.attempt_patch() - groq_patcher.attempt_patch() - instructor_patcher.attempt_patch() - dspy_patcher.attempt_patch() - cerebras_patcher.attempt_patch() - cohere_patcher.attempt_patch() - google_genai_patcher.attempt_patch() - notdiamond_patcher.attempt_patch() - vertexai_patcher.attempt_patch() + llamaindex_patcher.attempt_patch() def reset_autopatch() -> None: - from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher - from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher - from weave.integrations.cohere.cohere_sdk import cohere_patcher - from weave.integrations.dspy.dspy_sdk import dspy_patcher + from weave.integrations.anthropic.anthropic_sdk import get_anthropic_patcher + from weave.integrations.cerebras.cerebras_sdk import get_cerebras_patcher + from weave.integrations.cohere.cohere_sdk import get_cohere_patcher + from weave.integrations.dspy.dspy_sdk import get_dspy_patcher from weave.integrations.google_ai_studio.google_ai_studio_sdk import ( - google_genai_patcher, + get_google_genai_patcher, ) - from weave.integrations.groq.groq_sdk import groq_patcher - from weave.integrations.instructor.instructor_sdk import instructor_patcher + from weave.integrations.groq.groq_sdk import get_groq_patcher + from weave.integrations.instructor.instructor_sdk import get_instructor_patcher from weave.integrations.langchain.langchain import langchain_patcher - from weave.integrations.litellm.litellm import litellm_patcher + from weave.integrations.litellm.litellm import get_litellm_patcher from weave.integrations.llamaindex.llamaindex import llamaindex_patcher - from weave.integrations.mistral import mistral_patcher - from weave.integrations.notdiamond.tracing import notdiamond_patcher + from weave.integrations.mistral import get_mistral_patcher + from weave.integrations.notdiamond.tracing import get_notdiamond_patcher from weave.integrations.openai.openai_sdk import get_openai_patcher - from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher + from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher get_openai_patcher().undo_patch() - mistral_patcher.undo_patch() - litellm_patcher.undo_patch() - llamaindex_patcher.undo_patch() + get_mistral_patcher().undo_patch() + get_litellm_patcher().undo_patch() + get_anthropic_patcher().undo_patch() + get_groq_patcher().undo_patch() + get_instructor_patcher().undo_patch() + get_dspy_patcher().undo_patch() + get_cerebras_patcher().undo_patch() + get_cohere_patcher().undo_patch() + get_google_genai_patcher().undo_patch() + get_notdiamond_patcher().undo_patch() + get_vertexai_patcher().undo_patch() + langchain_patcher.undo_patch() - anthropic_patcher.undo_patch() - groq_patcher.undo_patch() - instructor_patcher.undo_patch() - dspy_patcher.undo_patch() - cerebras_patcher.undo_patch() - cohere_patcher.undo_patch() - google_genai_patcher.undo_patch() - notdiamond_patcher.undo_patch() - vertexai_patcher.undo_patch() + llamaindex_patcher.undo_patch() From f3b8c0335fdeafc77ed5183aec70386c7a741708 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 11 Dec 2024 19:09:19 -0500 Subject: [PATCH 3/5] test --- weave/integrations/dspy/dspy_sdk.py | 470 ++++++++++++++--------- weave/integrations/mistral/v0/mistral.py | 4 +- 2 files changed, 292 insertions(+), 182 deletions(-) diff --git a/weave/integrations/dspy/dspy_sdk.py b/weave/integrations/dspy/dspy_sdk.py index b205c77a458..579701960ad 100644 --- a/weave/integrations/dspy/dspy_sdk.py +++ b/weave/integrations/dspy/dspy_sdk.py @@ -1,216 +1,324 @@ import importlib -from typing import Callable +from typing import Callable, Optional import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.patcher import MultiPatcher, SymbolPatcher +_dspy_patcher: Optional[MultiPatcher] = None -def dspy_wrapper(name: str) -> Callable[[Callable], Callable]: + +def dspy_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(fn, **op_kwargs) return op return wrapper def dspy_get_patched_lm_functions( - base_symbol: str, lm_class_name: str + base_symbol: str, lm_class_name: str, settings: OpSettings ) -> list[SymbolPatcher]: patchable_functional_attributes = [ "basic_request", "request", "__call__", ] + basic_request_settings = settings.model_copy( + update={"name": settings.name or f"{base_symbol}.{lm_class_name}.basic_request"} + ) + request_settings = settings.model_copy( + update={"name": settings.name or f"{base_symbol}.{lm_class_name}.request"} + ) + call_settings = settings.model_copy( + update={"name": settings.name or f"{base_symbol}.{lm_class_name}"} + ) return [ SymbolPatcher( get_base_symbol=lambda: importlib.import_module(base_symbol), attribute_name=f"{lm_class_name}.basic_request", - make_new_value=dspy_wrapper(f"dspy.{lm_class_name}.basic_request"), + make_new_value=dspy_wrapper(basic_request_settings), ), SymbolPatcher( get_base_symbol=lambda: importlib.import_module(base_symbol), attribute_name=f"{lm_class_name}.request", - make_new_value=dspy_wrapper(f"dspy.{lm_class_name}.request"), + make_new_value=dspy_wrapper(request_settings), ), SymbolPatcher( get_base_symbol=lambda: importlib.import_module(base_symbol), attribute_name=f"{lm_class_name}.__call__", - make_new_value=dspy_wrapper(f"dspy.{lm_class_name}"), + make_new_value=dspy_wrapper(call_settings), + ), + ] + + +def get_dspy_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: + global _dspy_patcher + if _dspy_patcher is not None: + return _dspy_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + + predict_call_settings = base.model_copy( + update={"name": base.name or "dspy.Predict"} + ) + predict_forward_settings = base.model_copy( + update={"name": base.name or "dspy.Predict.forward"} + ) + typed_predictor_call_settings = base.model_copy( + update={"name": base.name or "dspy.TypedPredictor"} + ) + typed_predictor_forward_settings = base.model_copy( + update={"name": base.name or "dspy.TypedPredictor.forward"} + ) + module_call_settings = base.model_copy(update={"name": base.name or "dspy.Module"}) + typed_chain_of_thought_call_settings = base.model_copy( + update={"name": base.name or "dspy.TypedChainOfThought"} + ) + retrieve_call_settings = base.model_copy( + update={"name": base.name or "dspy.Retrieve"} + ) + retrieve_forward_settings = base.model_copy( + update={"name": base.name or "dspy.Retrieve.forward"} + ) + evaluate_call_settings = base.model_copy( + update={"name": base.name or "dspy.evaluate.Evaluate"} + ) + bootstrap_few_shot_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.BootstrapFewShot.compile"} + ) + copro_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.COPRO.compile"} + ) + ensemble_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.Ensemble.compile"} + ) + bootstrap_finetune_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.BootstrapFinetune.compile"} + ) + knn_few_shot_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.KNNFewShot.compile"} + ) + mipro_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.MIPRO.compile"} + ) + bootstrap_few_shot_with_random_search_compile_settings = base.model_copy( + update={ + "name": base.name + or "dspy.teleprompt.BootstrapFewShotWithRandomSearch.compile" + } + ) + signature_optimizer_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.SignatureOptimizer.compile"} + ) + bayesian_signature_optimizer_compile_settings = base.model_copy( + update={ + "name": base.name or "dspy.teleprompt.BayesianSignatureOptimizer.compile" + } + ) + signature_opt_typed_optimize_signature_settings = base.model_copy( + update={ + "name": base.name + or "dspy.teleprompt.signature_opt_typed.optimize_signature" + } + ) + bootstrap_few_shot_with_optuna_compile_settings = base.model_copy( + update={ + "name": base.name or "dspy.teleprompt.BootstrapFewShotWithOptuna.compile" + } + ) + labeled_few_shot_compile_settings = base.model_copy( + update={"name": base.name or "dspy.teleprompt.LabeledFewShot.compile"} + ) + + patched_functions = [ + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Predict.__call__", + make_new_value=dspy_wrapper(predict_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Predict.forward", + make_new_value=dspy_wrapper(predict_forward_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="TypedPredictor.__call__", + make_new_value=dspy_wrapper(typed_predictor_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="TypedPredictor.forward", + make_new_value=dspy_wrapper(typed_predictor_forward_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Module.__call__", + make_new_value=dspy_wrapper(module_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="TypedChainOfThought.__call__", + make_new_value=dspy_wrapper(typed_chain_of_thought_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Retrieve.__call__", + make_new_value=dspy_wrapper(retrieve_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Retrieve.forward", + make_new_value=dspy_wrapper(retrieve_forward_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.evaluate.evaluate"), + attribute_name="Evaluate.__call__", + make_new_value=dspy_wrapper(evaluate_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="BootstrapFewShot.compile", + make_new_value=dspy_wrapper(bootstrap_few_shot_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="COPRO.compile", + make_new_value=dspy_wrapper(copro_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="Ensemble.compile", + make_new_value=dspy_wrapper(ensemble_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="BootstrapFinetune.compile", + make_new_value=dspy_wrapper(bootstrap_finetune_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="KNNFewShot.compile", + make_new_value=dspy_wrapper(knn_few_shot_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="MIPRO.compile", + make_new_value=dspy_wrapper(mipro_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="BootstrapFewShotWithRandomSearch.compile", + make_new_value=dspy_wrapper( + bootstrap_few_shot_with_random_search_compile_settings + ), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="SignatureOptimizer.compile", + make_new_value=dspy_wrapper(signature_optimizer_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="BayesianSignatureOptimizer.compile", + make_new_value=dspy_wrapper(bayesian_signature_optimizer_compile_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module( + "dspy.teleprompt.signature_opt_typed" + ), + attribute_name="optimize_signature", + make_new_value=dspy_wrapper( + signature_opt_typed_optimize_signature_settings + ), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="BootstrapFewShotWithOptuna.compile", + make_new_value=dspy_wrapper( + bootstrap_few_shot_with_optuna_compile_settings + ), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), + attribute_name="LabeledFewShot.compile", + make_new_value=dspy_wrapper(labeled_few_shot_compile_settings), + ), + ] + + # Patch LM classes + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="AzureOpenAI", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="OpenAI", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="Cohere", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="Clarifai", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="Google", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="HFClientTGI", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="HFClientVLLM", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="Anyscale", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="Together", settings=base + ) + patched_functions += dspy_get_patched_lm_functions( + base_symbol="dspy", lm_class_name="OllamaLocal", settings=base + ) + + databricks_basic_request_settings = base.model_copy( + update={"name": base.name or "dspy.Databricks.basic_request"} + ) + databricks_call_settings = base.model_copy( + update={"name": base.name or "dspy.Databricks"} + ) + colbertv2_call_settings = base.model_copy( + update={"name": base.name or "dspy.ColBERTv2"} + ) + pyserini_call_settings = base.model_copy( + update={"name": base.name or "dspy.Pyserini"} + ) + + patched_functions += [ + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Databricks.basic_request", + make_new_value=dspy_wrapper(databricks_basic_request_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Databricks.__call__", + make_new_value=dspy_wrapper(databricks_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="ColBERTv2.__call__", + make_new_value=dspy_wrapper(colbertv2_call_settings), + ), + SymbolPatcher( + get_base_symbol=lambda: importlib.import_module("dspy"), + attribute_name="Pyserini.__call__", + make_new_value=dspy_wrapper(pyserini_call_settings), ), ] + _dspy_patcher = MultiPatcher(patched_functions) -patched_functions = [ - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Predict.__call__", - make_new_value=dspy_wrapper("dspy.Predict"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Predict.forward", - make_new_value=dspy_wrapper("dspy.Predict.forward"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="TypedPredictor.__call__", - make_new_value=dspy_wrapper("dspy.TypedPredictor"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="TypedPredictor.forward", - make_new_value=dspy_wrapper("dspy.TypedPredictor.forward"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Module.__call__", - make_new_value=dspy_wrapper("dspy.Module"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="TypedChainOfThought.__call__", - make_new_value=dspy_wrapper("dspy.TypedChainOfThought"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Retrieve.__call__", - make_new_value=dspy_wrapper("dspy.Retrieve"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Retrieve.forward", - make_new_value=dspy_wrapper("dspy.Retrieve.forward"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.evaluate.evaluate"), - attribute_name="Evaluate.__call__", - make_new_value=dspy_wrapper("dspy.evaluate.Evaluate"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="BootstrapFewShot.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.BootstrapFewShot.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="COPRO.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.COPRO.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="Ensemble.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.Ensemble.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="BootstrapFinetune.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.BootstrapFinetune.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="KNNFewShot.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.KNNFewShot.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="MIPRO.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.MIPRO.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="BootstrapFewShotWithRandomSearch.compile", - make_new_value=dspy_wrapper( - "dspy.teleprompt.BootstrapFewShotWithRandomSearch.compile" - ), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="SignatureOptimizer.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.SignatureOptimizer.compile"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="BayesianSignatureOptimizer.compile", - make_new_value=dspy_wrapper( - "dspy.teleprompt.BayesianSignatureOptimizer.compile" - ), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module( - "dspy.teleprompt.signature_opt_typed" - ), - attribute_name="optimize_signature", - make_new_value=dspy_wrapper( - "dspy.teleprompt.signature_opt_typed.optimize_signature" - ), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="BootstrapFewShotWithOptuna.compile", - make_new_value=dspy_wrapper( - "dspy.teleprompt.BootstrapFewShotWithOptuna.compile" - ), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy.teleprompt"), - attribute_name="LabeledFewShot.compile", - make_new_value=dspy_wrapper("dspy.teleprompt.LabeledFewShot.compile"), - ), -] - -# Patch LM classes -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="AzureOpenAI" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="OpenAI" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="Cohere" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="Clarifai" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="Google" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="HFClientTGI" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="HFClientVLLM" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="Anyscale" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="Together" -) -patched_functions += dspy_get_patched_lm_functions( - base_symbol="dspy", lm_class_name="OllamaLocal" -) -patched_functions += [ - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Databricks.basic_request", - make_new_value=dspy_wrapper("dspy.Databricks.basic_request"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Databricks.__call__", - make_new_value=dspy_wrapper("dspy.Databricks"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="ColBERTv2.__call__", - make_new_value=dspy_wrapper("dspy.ColBERTv2"), - ), - SymbolPatcher( - get_base_symbol=lambda: importlib.import_module("dspy"), - attribute_name="Pyserini.__call__", - make_new_value=dspy_wrapper("dspy.Pyserini"), - ), -] - -dspy_patcher = MultiPatcher(patched_functions) + return _dspy_patcher diff --git a/weave/integrations/mistral/v0/mistral.py b/weave/integrations/mistral/v0/mistral.py index 70c81f449e2..8281cf51270 100644 --- a/weave/integrations/mistral/v0/mistral.py +++ b/weave/integrations/mistral/v0/mistral.py @@ -116,7 +116,7 @@ def get_mistral_patcher(settings: Optional[IntegrationSettings] = None) -> Multi update={"name": base.name or "mistralai.async_client.chat_stream"} ) - mistral_patcher = MultiPatcher( + _mistral_patcher = MultiPatcher( [ # Patch the sync, non-streaming chat method SymbolPatcher( @@ -144,3 +144,5 @@ def get_mistral_patcher(settings: Optional[IntegrationSettings] = None) -> Multi ), ] ) + + return _mistral_patcher From 1fc178cf4748871c1098ac4e1354eee2fd0fab6e Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 11 Dec 2024 19:22:06 -0500 Subject: [PATCH 4/5] test --- weave/integrations/anthropic/anthropic_sdk.py | 38 ++++++++++--------- weave/integrations/cerebras/cerebras_sdk.py | 22 ++++++----- weave/integrations/cohere/cohere_sdk.py | 32 ++++++++-------- weave/integrations/dspy/dspy_sdk.py | 21 ++++++---- .../google_ai_studio/google_ai_studio_sdk.py | 30 ++++++++------- weave/integrations/groq/groq_sdk.py | 30 +++++++++------ .../integrations/instructor/instructor_sdk.py | 21 +++++----- weave/integrations/litellm/litellm.py | 28 ++++++++------ weave/integrations/mistral/v0/mistral.py | 28 ++++++++------ weave/integrations/mistral/v1/mistral.py | 28 ++++++++------ weave/integrations/notdiamond/tracing.py | 22 ++++++----- weave/integrations/vertexai/vertexai_sdk.py | 28 +++++++------- weave/trace/autopatch.py | 7 ++-- 13 files changed, 192 insertions(+), 143 deletions(-) diff --git a/weave/integrations/anthropic/anthropic_sdk.py b/weave/integrations/anthropic/anthropic_sdk.py index e583e8de1bc..9cd06f53259 100644 --- a/weave/integrations/anthropic/anthropic_sdk.py +++ b/weave/integrations/anthropic/anthropic_sdk.py @@ -1,24 +1,26 @@ +from __future__ import annotations + import importlib from collections.abc import AsyncIterator, Iterator from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable import weave from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import _IteratorWrapper, add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from anthropic.lib.streaming import MessageStream from anthropic.types import Message, MessageStreamEvent -_anthropic_patcher: Optional[MultiPatcher] = None +_anthropic_patcher: MultiPatcher | None = None def anthropic_accumulator( - acc: Optional["Message"], - value: "MessageStreamEvent", -) -> "Message": + acc: Message | None, + value: MessageStreamEvent, +) -> Message: from anthropic.types import ( ContentBlockDeltaEvent, Message, @@ -116,9 +118,9 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: def anthropic_stream_accumulator( - acc: Optional["Message"], - value: "MessageStream", -) -> "Message": + acc: Message | None, + value: MessageStream, +) -> Message: from anthropic.lib.streaming._types import MessageStopEvent if acc is None: @@ -143,7 +145,7 @@ def __getattr__(self, name: str) -> Any: return object.__getattribute__(self, name) return getattr(self._iterator_or_ctx_manager, name) - def __stream_text__(self) -> Union[Iterator[str], AsyncIterator[str]]: + def __stream_text__(self) -> Iterator[str] | AsyncIterator[str]: if isinstance(self._iterator_or_ctx_manager, AsyncIterator): return self.__async_stream_text__() else: @@ -160,7 +162,7 @@ async def __async_stream_text__(self) -> AsyncIterator[str]: # type: ignore yield chunk.delta.text # type: ignore @property - def text_stream(self) -> Union[Iterator[str], AsyncIterator[str]]: + def text_stream(self) -> Iterator[str] | AsyncIterator[str]: return self.__stream_text__() @@ -179,16 +181,18 @@ def wrapper(fn: Callable) -> Callable: def get_anthropic_patcher( - settings: Optional[IntegrationSettings] = None, -) -> MultiPatcher: - global _anthropic_patcher + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + global _anthropic_patcher if _anthropic_patcher is not None: return _anthropic_patcher - if settings is None: - settings = IntegrationSettings() - base = settings.op_settings messages_create_settings = base.model_copy( diff --git a/weave/integrations/cerebras/cerebras_sdk.py b/weave/integrations/cerebras/cerebras_sdk.py index 1b6fee49006..a2096a184e7 100644 --- a/weave/integrations/cerebras/cerebras_sdk.py +++ b/weave/integrations/cerebras/cerebras_sdk.py @@ -1,12 +1,14 @@ +from __future__ import annotations + import importlib from functools import wraps -from typing import Any, Callable, Optional +from typing import Any, Callable import weave from weave.trace.autopatch import IntegrationSettings, OpSettings -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher -_cerebras_patcher: Optional[MultiPatcher] = None +_cerebras_patcher: MultiPatcher | None = None def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: @@ -35,16 +37,18 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: def get_cerebras_patcher( - settings: Optional[IntegrationSettings] = None, -) -> MultiPatcher: - global _cerebras_patcher + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + global _cerebras_patcher if _cerebras_patcher is not None: return _cerebras_patcher - if settings is None: - settings = IntegrationSettings() - base = settings.op_settings create_settings = base.model_copy( diff --git a/weave/integrations/cohere/cohere_sdk.py b/weave/integrations/cohere/cohere_sdk.py index e554e659478..a9b216c070b 100644 --- a/weave/integrations/cohere/cohere_sdk.py +++ b/weave/integrations/cohere/cohere_sdk.py @@ -1,24 +1,23 @@ +from __future__ import annotations + import importlib from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import weave from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from cohere.types.non_streamed_chat_response import NonStreamedChatResponse from cohere.v2.types.non_streamed_chat_response2 import NonStreamedChatResponse2 -_cohere_patcher: Optional[MultiPatcher] = None +_cohere_patcher: MultiPatcher | None = None -def cohere_accumulator( - acc: Optional[dict], - value: Any, -) -> "NonStreamedChatResponse": +def cohere_accumulator(acc: dict | None, value: Any) -> NonStreamedChatResponse: # don't need to accumulate, is build-in by cohere! # https://docs.cohere.com/docs/streaming # A stream-end event is the final event of the stream, and is returned only when streaming is finished. @@ -35,10 +34,7 @@ def cohere_accumulator( return acc -def cohere_accumulator_v2( - acc: Optional[dict], - value: Any, -) -> "NonStreamedChatResponse2": +def cohere_accumulator_v2(acc: dict | None, value: Any) -> NonStreamedChatResponse2: from cohere.v2.types.assistant_message_response import AssistantMessageResponse from cohere.v2.types.non_streamed_chat_response2 import NonStreamedChatResponse2 @@ -185,15 +181,19 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def get_cohere_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: - global _cohere_patcher +def get_cohere_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + global _cohere_patcher if _cohere_patcher is not None: return _cohere_patcher - if settings is None: - settings = IntegrationSettings() - base = settings.op_settings chat_settings = base.model_copy(update={"name": base.name or "cohere.Client.chat"}) diff --git a/weave/integrations/dspy/dspy_sdk.py b/weave/integrations/dspy/dspy_sdk.py index 579701960ad..25293b0f494 100644 --- a/weave/integrations/dspy/dspy_sdk.py +++ b/weave/integrations/dspy/dspy_sdk.py @@ -1,11 +1,13 @@ +from __future__ import annotations + import importlib -from typing import Callable, Optional +from typing import Callable import weave from weave.trace.autopatch import IntegrationSettings, OpSettings -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher -_dspy_patcher: Optional[MultiPatcher] = None +_dspy_patcher: MultiPatcher | None = None def dspy_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]: @@ -53,14 +55,19 @@ def dspy_get_patched_lm_functions( ] -def get_dspy_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: +def get_dspy_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + global _dspy_patcher if _dspy_patcher is not None: return _dspy_patcher - if settings is None: - settings = IntegrationSettings() - base = settings.op_settings predict_call_settings = base.model_copy( diff --git a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py index d86fe5b11f3..7c4e6d2a740 100644 --- a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py +++ b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py @@ -1,18 +1,20 @@ +from __future__ import annotations + import importlib from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import weave from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher from weave.trace.serialize import dictify from weave.trace.weave_client import Call if TYPE_CHECKING: from google.generativeai.types.generation_types import GenerateContentResponse -_google_genai_patcher: Optional[MultiPatcher] = None +_google_genai_patcher: MultiPatcher | None = None def gemini_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: @@ -22,8 +24,8 @@ def gemini_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: def gemini_accumulator( - acc: Optional["GenerateContentResponse"], value: "GenerateContentResponse" -) -> "GenerateContentResponse": + acc: GenerateContentResponse | None, value: GenerateContentResponse +) -> GenerateContentResponse: if acc is None: return value @@ -67,9 +69,7 @@ def gemini_accumulator( return acc -def gemini_on_finish( - call: Call, output: Any, exception: Optional[BaseException] -) -> None: +def gemini_on_finish(call: Call, output: Any, exception: BaseException | None) -> None: if "model_name" in call.inputs["self"]: original_model_name = call.inputs["self"]["model_name"] elif "model" in call.inputs["self"]: @@ -136,16 +136,18 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: def get_google_genai_patcher( - settings: Optional[IntegrationSettings] = None, -) -> MultiPatcher: - global _google_genai_patcher + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + global _google_genai_patcher if _google_genai_patcher is not None: return _google_genai_patcher - if settings is None: - settings = IntegrationSettings() - base = settings.op_settings generate_content_settings = base.model_copy( diff --git a/weave/integrations/groq/groq_sdk.py b/weave/integrations/groq/groq_sdk.py index fe8db33bcf5..c5c07fd705f 100644 --- a/weave/integrations/groq/groq_sdk.py +++ b/weave/integrations/groq/groq_sdk.py @@ -1,21 +1,23 @@ +from __future__ import annotations + import importlib -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable +import weave from weave.trace.autopatch import IntegrationSettings, OpSettings +from weave.trace.op_extensions.accumulator import add_accumulator +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from groq.types.chat import ChatCompletion, ChatCompletionChunk -import weave -from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher -_groq_patcher: Optional[MultiPatcher] = None +_groq_patcher: MultiPatcher | None = None def groq_accumulator( - acc: Optional["ChatCompletion"], value: "ChatCompletionChunk" -) -> "ChatCompletion": + acc: ChatCompletion | None, value: ChatCompletionChunk +) -> ChatCompletion: from groq.types.chat import ChatCompletion, ChatCompletionMessage from groq.types.chat.chat_completion import Choice from groq.types.chat.chat_completion_chunk import Choice as ChoiceChunk @@ -100,15 +102,19 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def get_groq_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: - global _groq_patcher +def get_groq_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + global _groq_patcher if _groq_patcher is not None: return _groq_patcher - if settings is None: - settings = IntegrationSettings() - base = settings.op_settings chat_completions_settings = base.model_copy( diff --git a/weave/integrations/instructor/instructor_sdk.py b/weave/integrations/instructor/instructor_sdk.py index 2716d570cbb..00dfde029c9 100644 --- a/weave/integrations/instructor/instructor_sdk.py +++ b/weave/integrations/instructor/instructor_sdk.py @@ -1,26 +1,29 @@ +from __future__ import annotations + import importlib -from typing import Optional from weave.trace.autopatch import IntegrationSettings -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher from .instructor_iterable_utils import instructor_wrapper_async, instructor_wrapper_sync from .instructor_partial_utils import instructor_wrapper_partial -_instructor_patcher: Optional[MultiPatcher] = None +_instructor_patcher: MultiPatcher | None = None def get_instructor_patcher( - settings: Optional[IntegrationSettings] = None, -) -> MultiPatcher: - global _instructor_patcher + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + global _instructor_patcher if _instructor_patcher is not None: return _instructor_patcher - if settings is None: - settings = IntegrationSettings() - base = settings.op_settings create_settings = base.model_copy(update={"name": base.name or "Instructor.create"}) diff --git a/weave/integrations/litellm/litellm.py b/weave/integrations/litellm/litellm.py index 8a1818820a0..9ae6e492c84 100644 --- a/weave/integrations/litellm/litellm.py +++ b/weave/integrations/litellm/litellm.py @@ -1,22 +1,24 @@ +from __future__ import annotations + import importlib -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import weave from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from litellm.utils import ModelResponse -_litellm_patcher: Optional[MultiPatcher] = None +_litellm_patcher: MultiPatcher | None = None # This accumulator is nearly identical to the mistral accumulator, just with different types. def litellm_accumulator( - acc: Optional["ModelResponse"], - value: "ModelResponse", -) -> "ModelResponse": + acc: ModelResponse | None, + value: ModelResponse, +) -> ModelResponse: # This import should be safe at this point from litellm.utils import Choices, Message, ModelResponse, Usage @@ -99,15 +101,19 @@ def litellm_wrapper(fn: Callable) -> Callable: return litellm_wrapper -def get_litellm_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: - global _litellm_patcher +def get_litellm_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + global _litellm_patcher if _litellm_patcher is not None: return _litellm_patcher - if settings is None: - settings = IntegrationSettings() - base = settings.op_settings completion_settings = base.model_copy( diff --git a/weave/integrations/mistral/v0/mistral.py b/weave/integrations/mistral/v0/mistral.py index 8281cf51270..70a3fa183bb 100644 --- a/weave/integrations/mistral/v0/mistral.py +++ b/weave/integrations/mistral/v0/mistral.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import importlib -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable import weave from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from mistralai.models.chat_completion import ( @@ -12,13 +14,13 @@ ChatCompletionStreamResponse, ) -_mistral_patcher: Optional[MultiPatcher] = None +_mistral_patcher: MultiPatcher | None = None def mistral_accumulator( - acc: Optional["ChatCompletionResponse"], - value: "ChatCompletionStreamResponse", -) -> "ChatCompletionResponse": + acc: ChatCompletionResponse | None, + value: ChatCompletionStreamResponse, +) -> ChatCompletionResponse: # This import should be safe at this point from mistralai.models.chat_completion import ( ChatCompletionResponse, @@ -94,15 +96,19 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def get_mistral_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: - global _mistral_patcher +def get_mistral_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + global _mistral_patcher if _mistral_patcher is not None: return _mistral_patcher - if settings is None: - settings = IntegrationSettings() - base = settings.op_settings chat_settings = base.model_copy(update={"name": base.name or "mistralai.chat"}) diff --git a/weave/integrations/mistral/v1/mistral.py b/weave/integrations/mistral/v1/mistral.py index a7824803c40..d52d42af3c4 100644 --- a/weave/integrations/mistral/v1/mistral.py +++ b/weave/integrations/mistral/v1/mistral.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import importlib -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable import weave from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from mistralai.models import ( @@ -12,13 +14,13 @@ CompletionEvent, ) -_mistral_patcher: Optional[MultiPatcher] = None +_mistral_patcher: MultiPatcher | None = None def mistral_accumulator( - acc: Optional["ChatCompletionResponse"], - value: "CompletionEvent", -) -> "ChatCompletionResponse": + acc: ChatCompletionResponse | None, + value: CompletionEvent, +) -> ChatCompletionResponse: # This import should be safe at this point from mistralai.models import ( AssistantMessage, @@ -101,15 +103,19 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def get_mistral_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher: - global _mistral_patcher +def get_mistral_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + global _mistral_patcher if _mistral_patcher is not None: return _mistral_patcher - if settings is None: - settings = IntegrationSettings() - base = settings.op_settings chat_complete_settings = base.model_copy( update={"name": base.name or "mistralai.chat.complete"} diff --git a/weave/integrations/notdiamond/tracing.py b/weave/integrations/notdiamond/tracing.py index ae1e5217725..23589719be7 100644 --- a/weave/integrations/notdiamond/tracing.py +++ b/weave/integrations/notdiamond/tracing.py @@ -1,11 +1,13 @@ +from __future__ import annotations + import importlib -from typing import Callable, Optional +from typing import Callable import weave from weave.trace.autopatch import IntegrationSettings, OpSettings -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher -_notdiamond_patcher: Optional[MultiPatcher] = None +_notdiamond_patcher: MultiPatcher | None = None def nd_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]: @@ -41,16 +43,18 @@ def _patch_client_op(method_name: str) -> list[SymbolPatcher]: def get_notdiamond_patcher( - settings: Optional[IntegrationSettings] = None, -) -> MultiPatcher: - global _notdiamond_patcher + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + global _notdiamond_patcher if _notdiamond_patcher is not None: return _notdiamond_patcher - if settings is None: - settings = IntegrationSettings() - base = settings.op_settings model_select_settings = base.model_copy( diff --git a/weave/integrations/vertexai/vertexai_sdk.py b/weave/integrations/vertexai/vertexai_sdk.py index f016135d775..94550ce21bd 100644 --- a/weave/integrations/vertexai/vertexai_sdk.py +++ b/weave/integrations/vertexai/vertexai_sdk.py @@ -1,11 +1,13 @@ +from __future__ import annotations + import importlib from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import weave from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher from weave.trace.serialize import dictify from weave.trace.weave_client import Call @@ -13,7 +15,7 @@ from vertexai.generative_models import GenerationResponse -_vertexai_patcher: Optional[MultiPatcher] = None +_vertexai_patcher: MultiPatcher | None = None def vertexai_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: @@ -29,8 +31,8 @@ def vertexai_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: def vertexai_accumulator( - acc: Optional["GenerationResponse"], value: "GenerationResponse" -) -> "GenerationResponse": + acc: GenerationResponse | None, value: GenerationResponse +) -> GenerationResponse: from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types from google.cloud.aiplatform_v1beta1.types import ( prediction_service as gapic_prediction_service_types, @@ -66,7 +68,7 @@ def vertexai_accumulator( def vertexai_on_finish( - call: Call, output: Any, exception: Optional[BaseException] + call: Call, output: Any, exception: BaseException | None ) -> None: original_model_name = call.inputs["model_name"] model_name = original_model_name.split("/")[-1] @@ -128,17 +130,17 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -def get_vertexai_patcher( - settings: Optional[IntegrationSettings] = None, -) -> MultiPatcher: - global _vertexai_patcher +def get_vertexai_patcher(settings: IntegrationSettings | None = None) -> MultiPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + global _vertexai_patcher if _vertexai_patcher is not None: return _vertexai_patcher - if settings is None: - settings = IntegrationSettings() - base = settings.op_settings generate_content_settings = base.model_copy( diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py index c79cf3b315d..dd114b2c442 100644 --- a/weave/trace/autopatch.py +++ b/weave/trace/autopatch.py @@ -30,9 +30,10 @@ class IntegrationSettings(BaseModel): class AutopatchSettings(BaseModel): - """Settings for auto-patching integrations.""" + """Settings for auto-patching integrations. - # These will be uncommented as we add support for more integrations. Note that + NOTE: There are no langchain or llamaindex settings here because those integrations + don't make use of the op decorator.""" anthropic: IntegrationSettings = Field(default_factory=IntegrationSettings) cerebras: IntegrationSettings = Field(default_factory=IntegrationSettings) @@ -41,9 +42,7 @@ class AutopatchSettings(BaseModel): google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings) groq: IntegrationSettings = Field(default_factory=IntegrationSettings) instructor: IntegrationSettings = Field(default_factory=IntegrationSettings) - langchain: IntegrationSettings = Field(default_factory=IntegrationSettings) litellm: IntegrationSettings = Field(default_factory=IntegrationSettings) - llamaindex: IntegrationSettings = Field(default_factory=IntegrationSettings) mistral: IntegrationSettings = Field(default_factory=IntegrationSettings) notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings) openai: IntegrationSettings = Field(default_factory=IntegrationSettings) From 13554e3b54858d2da23f9bd505478adf05402a33 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 11 Dec 2024 20:04:29 -0500 Subject: [PATCH 5/5] test --- weave/integrations/vertexai/vertexai_sdk.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/weave/integrations/vertexai/vertexai_sdk.py b/weave/integrations/vertexai/vertexai_sdk.py index 94550ce21bd..a64620cbe5f 100644 --- a/weave/integrations/vertexai/vertexai_sdk.py +++ b/weave/integrations/vertexai/vertexai_sdk.py @@ -130,7 +130,9 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -def get_vertexai_patcher(settings: IntegrationSettings | None = None) -> MultiPatcher: +def get_vertexai_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: if settings is None: settings = IntegrationSettings()