test

andrewtruong committed Dec 10, 2024
1 parent 222a02a commit 6d4ae66
Showing 2 changed files with 86 additions and 41 deletions.
119 changes: 82 additions & 37 deletions weave/integrations/vertexai/vertexai_sdk.py
@@ -1,8 +1,10 @@
+import dataclasses
 import importlib
 from functools import wraps
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import weave
+from weave.trace.autopatch import IntegrationSettings, OpSettings
 from weave.trace.op_extensions.accumulator import add_accumulator
 from weave.trace.patcher import MultiPatcher, SymbolPatcher
 from weave.trace.serialize import dictify
@@ -11,6 +13,8 @@
 if TYPE_CHECKING:
     from vertexai.generative_models import GenerationResponse
 
+_vertexai_patcher: Optional[MultiPatcher] = None
+
 
 def vertexai_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]:
     if "self" in inputs:
@@ -81,10 +85,13 @@ def vertexai_on_finish(
     call.summary.update(summary_update)
 
 
-def vertexai_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
+def vertexai_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]:
     def wrapper(fn: Callable) -> Callable:
-        op = weave.op(postprocess_inputs=vertexai_postprocess_inputs)(fn)
-        op.name = name  # type: ignore
+        op_kwargs = dataclasses.asdict(settings)
+        if not settings.postprocess_inputs:
+            op_kwargs["postprocess_inputs"] = vertexai_postprocess_inputs
+
+        op = weave.op(fn, **op_kwargs)
         op._set_on_finish_handler(vertexai_on_finish)
         return add_accumulator(
             op,  # type: ignore
@@ -96,7 +103,7 @@ def wrapper(fn: Callable) -> Callable:
     return wrapper
 
 
-def vertexai_wrapper_async(name: str) -> Callable[[Callable], Callable]:
+def vertexai_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]:
     def wrapper(fn: Callable) -> Callable:
         def _fn_wrapper(fn: Callable) -> Callable:
             @wraps(fn)
@@ -105,9 +112,11 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
 
             return _async_wrapper
 
         "We need to do this so we can check if `stream` is used"
-        op = weave.op(postprocess_inputs=vertexai_postprocess_inputs)(_fn_wrapper(fn))
-        op.name = name  # type: ignore
+        op_kwargs = dataclasses.asdict(settings)
+        if not settings.postprocess_inputs:
+            op_kwargs["postprocess_inputs"] = vertexai_postprocess_inputs
+
+        op = weave.op(_fn_wrapper(fn), **op_kwargs)
         op._set_on_finish_handler(vertexai_on_finish)
         return add_accumulator(
             op,  # type: ignore
@@ -119,34 +128,70 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
     return wrapper
 
 
-vertexai_patcher = MultiPatcher(
-    [
-        SymbolPatcher(
-            lambda: importlib.import_module("vertexai.generative_models"),
-            "GenerativeModel.generate_content",
-            vertexai_wrapper_sync(name="vertexai.GenerativeModel.generate_content"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("vertexai.generative_models"),
-            "GenerativeModel.generate_content_async",
-            vertexai_wrapper_async(
-                name="vertexai.GenerativeModel.generate_content_async"
-            ),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("vertexai.generative_models"),
-            "ChatSession.send_message",
-            vertexai_wrapper_sync(name="vertexai.ChatSession.send_message"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("vertexai.generative_models"),
-            "ChatSession.send_message_async",
-            vertexai_wrapper_async(name="vertexai.ChatSession.send_message_async"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("vertexai.preview.vision_models"),
-            "ImageGenerationModel.generate_images",
-            vertexai_wrapper_sync(name="vertexai.ImageGenerationModel.generate_images"),
-        ),
-    ]
-)
+def get_vertexai_patcher(
+    settings: Optional[IntegrationSettings] = None,
+) -> MultiPatcher:
+    global _vertexai_patcher
+
+    if _vertexai_patcher is not None:
+        return _vertexai_patcher
+
+    if settings is None:
+        settings = IntegrationSettings()
+
+    base = settings.op_settings
+
+    generative_model_generate_content_settings = dataclasses.replace(
+        base,
+        name=base.name or "vertexai.GenerativeModel.generate_content",
+    )
+    generative_model_generate_content_async_settings = dataclasses.replace(
+        base,
+        name=base.name or "vertexai.GenerativeModel.generate_content_async",
+    )
+    chat_session_send_message_settings = dataclasses.replace(
+        base,
+        name=base.name or "vertexai.ChatSession.send_message",
+    )
+    chat_session_send_message_async_settings = dataclasses.replace(
+        base,
+        name=base.name or "vertexai.ChatSession.send_message_async",
+    )
+    image_generation_model_generate_images_settings = dataclasses.replace(
+        base,
+        name=base.name or "vertexai.ImageGenerationModel.generate_images",
+    )
+
+    _vertexai_patcher = MultiPatcher(
+        [
+            SymbolPatcher(
+                lambda: importlib.import_module("vertexai.generative_models"),
+                "GenerativeModel.generate_content",
+                vertexai_wrapper_sync(generative_model_generate_content_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("vertexai.generative_models"),
+                "GenerativeModel.generate_content_async",
+                vertexai_wrapper_async(
+                    generative_model_generate_content_async_settings
+                ),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("vertexai.generative_models"),
+                "ChatSession.send_message",
+                vertexai_wrapper_sync(chat_session_send_message_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("vertexai.generative_models"),
+                "ChatSession.send_message_async",
+                vertexai_wrapper_async(chat_session_send_message_async_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("vertexai.preview.vision_models"),
+                "ImageGenerationModel.generate_images",
+                vertexai_wrapper_sync(image_generation_model_generate_images_settings),
+            ),
+        ]
+    )
+
+    return _vertexai_patcher
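
Taken together, the vertexai_sdk.py changes replace the hard-coded op names with per-op settings: each wrapper now receives an OpSettings, falls back to vertexai_postprocess_inputs only when no postprocess_inputs is supplied, and the MultiPatcher is built lazily and cached in the module-level _vertexai_patcher. A minimal usage sketch (not part of this commit), assuming IntegrationSettings and OpSettings accept these fields as constructor keywords; only op_settings, name, and postprocess_inputs are visible in this diff:

    from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher
    from weave.trace.autopatch import IntegrationSettings, OpSettings

    # Override the default op name; each symbol's settings are derived with
    # dataclasses.replace(base, name=base.name or "<default>"), so a non-empty
    # base.name applies to every patched VertexAI symbol.
    settings = IntegrationSettings(op_settings=OpSettings(name="my_vertexai_op"))

    patcher = get_vertexai_patcher(settings)  # built once, then cached in _vertexai_patcher
    patcher.attempt_patch()  # wraps GenerativeModel, ChatSession and ImageGenerationModel methods
    # ... run VertexAI calls here; they are traced as weave ops ...
    patcher.undo_patch()  # restores the original SDK methods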
8 changes: 4 additions & 4 deletions weave/trace/autopatch.py
@@ -28,7 +28,7 @@ def autopatch(settings: AutopatchSettings | None = None) -> None:
     from weave.integrations.mistral import get_mistral_patcher
     from weave.integrations.notdiamond.tracing import get_notdiamond_patcher
     from weave.integrations.openai.openai_sdk import get_openai_patcher
-    from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher
+    from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher
 
     if settings is None:
         settings = AutopatchSettings()
@@ -46,7 +46,7 @@ def autopatch(settings: AutopatchSettings | None = None) -> None:
     get_cohere_patcher(settings.cohere).attempt_patch()
     get_google_genai_patcher(settings.google_ai_studio).attempt_patch()
     get_notdiamond_patcher(settings.notdiamond).attempt_patch()
-    vertexai_patcher.attempt_patch()
+    get_vertexai_patcher(settings.vertexai).attempt_patch()
 
 
 def reset_autopatch() -> None:
@@ -65,7 +65,7 @@ def reset_autopatch() -> None:
     from weave.integrations.mistral import get_mistral_patcher
     from weave.integrations.notdiamond.tracing import get_notdiamond_patcher
     from weave.integrations.openai.openai_sdk import get_openai_patcher
-    from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher
+    from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher
 
     get_openai_patcher().undo_patch()
     get_mistral_patcher().undo_patch()
@@ -80,7 +80,7 @@ def reset_autopatch() -> None:
     get_cohere_patcher().undo_patch()
     get_google_genai_patcher().undo_patch()
     get_notdiamond_patcher().undo_patch()
-    vertexai_patcher.undo_patch()
+    get_vertexai_patcher().undo_patch()
 
 
 @dataclass
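
On the autopatch.py side, the VertexAI integration now follows the same pattern as the other integrations: the patcher is obtained through get_vertexai_patcher() and AutopatchSettings carries a vertexai entry that is forwarded to it. A hedged end-to-end sketch; the keyword-style construction of AutopatchSettings and IntegrationSettings is assumed from the dataclass usage shown in this diff:

    from weave.trace.autopatch import (
        AutopatchSettings,
        IntegrationSettings,
        OpSettings,
        autopatch,
        reset_autopatch,
    )

    # settings.vertexai flows into get_vertexai_patcher(settings.vertexai).attempt_patch()
    settings = AutopatchSettings(
        vertexai=IntegrationSettings(op_settings=OpSettings(name="vertexai_call"))
    )

    autopatch(settings)  # patches all integrations, including VertexAI
    # ... traced VertexAI usage ...
    reset_autopatch()  # undoes every patcher, including get_vertexai_patcher().undo_patch()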
