From 6d4ae668c65915373536dff67b6a01061e93b61f Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 10 Dec 2024 15:11:59 -0500 Subject: [PATCH] test --- weave/integrations/vertexai/vertexai_sdk.py | 119 ++++++++++++++------ weave/trace/autopatch.py | 8 +- 2 files changed, 86 insertions(+), 41 deletions(-) diff --git a/weave/integrations/vertexai/vertexai_sdk.py b/weave/integrations/vertexai/vertexai_sdk.py index c2f7a9906c7..016f0917b87 100644 --- a/weave/integrations/vertexai/vertexai_sdk.py +++ b/weave/integrations/vertexai/vertexai_sdk.py @@ -1,8 +1,10 @@ +import dataclasses import importlib from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Optional import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator from weave.trace.patcher import MultiPatcher, SymbolPatcher from weave.trace.serialize import dictify @@ -11,6 +13,8 @@ if TYPE_CHECKING: from vertexai.generative_models import GenerationResponse +_vertexai_patcher: Optional[MultiPatcher] = None + def vertexai_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: if "self" in inputs: @@ -81,10 +85,13 @@ def vertexai_on_finish( call.summary.update(summary_update) -def vertexai_wrapper_sync(name: str) -> Callable[[Callable], Callable]: +def vertexai_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op(postprocess_inputs=vertexai_postprocess_inputs)(fn) - op.name = name # type: ignore + op_kwargs = dataclasses.asdict(settings) + if not settings.postprocess_inputs: + op_kwargs["postprocess_inputs"] = vertexai_postprocess_inputs + + op = weave.op(fn, **op_kwargs) op._set_on_finish_handler(vertexai_on_finish) return add_accumulator( op, # type: ignore @@ -96,7 +103,7 @@ def wrapper(fn: Callable) -> Callable: return wrapper -def vertexai_wrapper_async(name: str) -> Callable[[Callable], Callable]: +def vertexai_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) @@ -105,9 +112,11 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper - "We need to do this so we can check if `stream` is used" - op = weave.op(postprocess_inputs=vertexai_postprocess_inputs)(_fn_wrapper(fn)) - op.name = name # type: ignore + op_kwargs = dataclasses.asdict(settings) + if not settings.postprocess_inputs: + op_kwargs["postprocess_inputs"] = vertexai_postprocess_inputs + + op = weave.op(_fn_wrapper(fn), **op_kwargs) op._set_on_finish_handler(vertexai_on_finish) return add_accumulator( op, # type: ignore @@ -119,34 +128,70 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -vertexai_patcher = MultiPatcher( - [ - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "GenerativeModel.generate_content", - vertexai_wrapper_sync(name="vertexai.GenerativeModel.generate_content"), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "GenerativeModel.generate_content_async", - vertexai_wrapper_async( - name="vertexai.GenerativeModel.generate_content_async" +def get_vertexai_patcher( + settings: Optional[IntegrationSettings] = None, +) -> MultiPatcher: + global _vertexai_patcher + + if _vertexai_patcher is not None: + return _vertexai_patcher + + if settings is None: + settings = IntegrationSettings() + + base = settings.op_settings + + generative_model_generate_content_settings = dataclasses.replace( + base, + name=base.name or "vertexai.GenerativeModel.generate_content", + ) + generative_model_generate_content_async_settings = dataclasses.replace( + base, + name=base.name or "vertexai.GenerativeModel.generate_content_async", + ) + chat_session_send_message_settings = dataclasses.replace( + base, + name=base.name or "vertexai.ChatSession.send_message", + ) + chat_session_send_message_async_settings = dataclasses.replace( + base, + name=base.name or "vertexai.ChatSession.send_message_async", + ) + image_generation_model_generate_images_settings = dataclasses.replace( + base, + name=base.name or "vertexai.ImageGenerationModel.generate_images", + ) + + _vertexai_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "GenerativeModel.generate_content", + vertexai_wrapper_sync(generative_model_generate_content_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "GenerativeModel.generate_content_async", + vertexai_wrapper_async( + generative_model_generate_content_async_settings + ), + ), + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "ChatSession.send_message", + vertexai_wrapper_sync(chat_session_send_message_settings), ), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "ChatSession.send_message", - vertexai_wrapper_sync(name="vertexai.ChatSession.send_message"), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.generative_models"), - "ChatSession.send_message_async", - vertexai_wrapper_async(name="vertexai.ChatSession.send_message_async"), - ), - SymbolPatcher( - lambda: importlib.import_module("vertexai.preview.vision_models"), - "ImageGenerationModel.generate_images", - vertexai_wrapper_sync(name="vertexai.ImageGenerationModel.generate_images"), - ), - ] -) + SymbolPatcher( + lambda: importlib.import_module("vertexai.generative_models"), + "ChatSession.send_message_async", + vertexai_wrapper_async(chat_session_send_message_async_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("vertexai.preview.vision_models"), + "ImageGenerationModel.generate_images", + vertexai_wrapper_sync(image_generation_model_generate_images_settings), + ), + ] + ) + + return _vertexai_patcher diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py index a9f1f6d3dfa..8500581d06e 100644 --- a/weave/trace/autopatch.py +++ b/weave/trace/autopatch.py @@ -28,7 +28,7 @@ def autopatch(settings: AutopatchSettings | None = None) -> None: from weave.integrations.mistral import get_mistral_patcher from weave.integrations.notdiamond.tracing import get_notdiamond_patcher from weave.integrations.openai.openai_sdk import get_openai_patcher - from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher + from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher if settings is None: settings = AutopatchSettings() @@ -46,7 +46,7 @@ def autopatch(settings: AutopatchSettings | None = None) -> None: get_cohere_patcher(settings.cohere).attempt_patch() get_google_genai_patcher(settings.google_ai_studio).attempt_patch() get_notdiamond_patcher(settings.notdiamond).attempt_patch() - vertexai_patcher.attempt_patch() + get_vertexai_patcher(settings.vertexai).attempt_patch() def reset_autopatch() -> None: @@ -65,7 +65,7 @@ def reset_autopatch() -> None: from weave.integrations.mistral import get_mistral_patcher from weave.integrations.notdiamond.tracing import get_notdiamond_patcher from weave.integrations.openai.openai_sdk import get_openai_patcher - from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher + from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher get_openai_patcher().undo_patch() get_mistral_patcher().undo_patch() @@ -80,7 +80,7 @@ def reset_autopatch() -> None: get_cohere_patcher().undo_patch() get_google_genai_patcher().undo_patch() get_notdiamond_patcher().undo_patch() - vertexai_patcher.undo_patch() + get_vertexai_patcher().undo_patch() @dataclass