Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Dec 12, 2024
1 parent 3ebdb4b commit 3571f93
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 47 deletions.
126 changes: 84 additions & 42 deletions weave/integrations/vertexai/vertexai_sdk.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from __future__ import annotations

import importlib
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable

import weave
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.patcher import MultiPatcher, SymbolPatcher
from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.serialize import dictify
from weave.trace.weave_client import Call

if TYPE_CHECKING:
from vertexai.generative_models import GenerationResponse


_vertexai_patcher: MultiPatcher | None = None


def vertexai_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]:
if "self" in inputs:
model_name = (
Expand All @@ -25,8 +31,8 @@ def vertexai_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]:


def vertexai_accumulator(
acc: Optional["GenerationResponse"], value: "GenerationResponse"
) -> "GenerationResponse":
acc: GenerationResponse | None, value: GenerationResponse
) -> GenerationResponse:
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
from google.cloud.aiplatform_v1beta1.types import (
prediction_service as gapic_prediction_service_types,
Expand Down Expand Up @@ -62,7 +68,7 @@ def vertexai_accumulator(


def vertexai_on_finish(
call: Call, output: Any, exception: Optional[BaseException]
call: Call, output: Any, exception: BaseException | None
) -> None:
original_model_name = call.inputs["model_name"]
model_name = original_model_name.split("/")[-1]
Expand All @@ -81,10 +87,13 @@ def vertexai_on_finish(
call.summary.update(summary_update)


def vertexai_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
def vertexai_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
op = weave.op(postprocess_inputs=vertexai_postprocess_inputs)(fn)
op.name = name # type: ignore
op_kwargs = settings.model_copy()
if not op_kwargs.get("postprocess_inputs"):
op_kwargs["postprocess_inputs"] = vertexai_postprocess_inputs

op = weave.op(fn, **op_kwargs)
op._set_on_finish_handler(vertexai_on_finish)
return add_accumulator(
op, # type: ignore
Expand All @@ -96,7 +105,7 @@ def wrapper(fn: Callable) -> Callable:
return wrapper


def vertexai_wrapper_async(name: str) -> Callable[[Callable], Callable]:
def vertexai_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
def _fn_wrapper(fn: Callable) -> Callable:
@wraps(fn)
Expand All @@ -105,9 +114,11 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:

return _async_wrapper

"We need to do this so we can check if `stream` is used"
op = weave.op(postprocess_inputs=vertexai_postprocess_inputs)(_fn_wrapper(fn))
op.name = name # type: ignore
op_kwargs = settings.model_copy()
if not op_kwargs.get("postprocess_inputs"):
op_kwargs["postprocess_inputs"] = vertexai_postprocess_inputs

op = weave.op(_fn_wrapper(fn), **op_kwargs)
op._set_on_finish_handler(vertexai_on_finish)
return add_accumulator(
op, # type: ignore
Expand All @@ -119,34 +130,65 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
return wrapper


vertexai_patcher = MultiPatcher(
[
SymbolPatcher(
lambda: importlib.import_module("vertexai.generative_models"),
"GenerativeModel.generate_content",
vertexai_wrapper_sync(name="vertexai.GenerativeModel.generate_content"),
),
SymbolPatcher(
lambda: importlib.import_module("vertexai.generative_models"),
"GenerativeModel.generate_content_async",
vertexai_wrapper_async(
name="vertexai.GenerativeModel.generate_content_async"
def get_vertexai_patcher(
settings: IntegrationSettings | None = None,
) -> MultiPatcher | NoOpPatcher:
if settings is None:
settings = IntegrationSettings()

if not settings.enabled:
return NoOpPatcher()

global _vertexai_patcher
if _vertexai_patcher is not None:
return _vertexai_patcher

base = settings.op_settings

generate_content_settings = base.model_copy(
update={"name": base.name or "vertexai.GenerativeModel.generate_content"}
)
generate_content_async_settings = base.model_copy(
update={"name": base.name or "vertexai.GenerativeModel.generate_content_async"}
)
send_message_settings = base.model_copy(
update={"name": base.name or "vertexai.ChatSession.send_message"}
)
send_message_async_settings = base.model_copy(
update={"name": base.name or "vertexai.ChatSession.send_message_async"}
)
generate_images_settings = base.model_copy(
update={"name": base.name or "vertexai.ImageGenerationModel.generate_images"}
)

_vertexai_patcher = MultiPatcher(
[
SymbolPatcher(
lambda: importlib.import_module("vertexai.generative_models"),
"GenerativeModel.generate_content",
vertexai_wrapper_sync(generate_content_settings),
),
SymbolPatcher(
lambda: importlib.import_module("vertexai.generative_models"),
"GenerativeModel.generate_content_async",
vertexai_wrapper_async(generate_content_async_settings),
),
),
SymbolPatcher(
lambda: importlib.import_module("vertexai.generative_models"),
"ChatSession.send_message",
vertexai_wrapper_sync(name="vertexai.ChatSession.send_message"),
),
SymbolPatcher(
lambda: importlib.import_module("vertexai.generative_models"),
"ChatSession.send_message_async",
vertexai_wrapper_async(name="vertexai.ChatSession.send_message_async"),
),
SymbolPatcher(
lambda: importlib.import_module("vertexai.preview.vision_models"),
"ImageGenerationModel.generate_images",
vertexai_wrapper_sync(name="vertexai.ImageGenerationModel.generate_images"),
),
]
)
SymbolPatcher(
lambda: importlib.import_module("vertexai.generative_models"),
"ChatSession.send_message",
vertexai_wrapper_sync(send_message_settings),
),
SymbolPatcher(
lambda: importlib.import_module("vertexai.generative_models"),
"ChatSession.send_message_async",
vertexai_wrapper_async(send_message_async_settings),
),
SymbolPatcher(
lambda: importlib.import_module("vertexai.preview.vision_models"),
"ImageGenerationModel.generate_images",
vertexai_wrapper_sync(generate_images_settings),
),
]
)

return _vertexai_patcher
10 changes: 5 additions & 5 deletions weave/trace/autopatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class AutopatchSettings(BaseModel):
mistral: IntegrationSettings = Field(default_factory=IntegrationSettings)
notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings)
openai: IntegrationSettings = Field(default_factory=IntegrationSettings)
# vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings)
vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings)


@validate_call
Expand All @@ -65,7 +65,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
from weave.integrations.mistral import get_mistral_patcher
from weave.integrations.notdiamond.tracing import get_notdiamond_patcher
from weave.integrations.openai.openai_sdk import get_openai_patcher
from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher
from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher

if settings is None:
settings = AutopatchSettings()
Expand All @@ -83,7 +83,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
get_cohere_patcher(settings.cohere).attempt_patch()
get_google_genai_patcher(settings.google_ai_studio).attempt_patch()
get_notdiamond_patcher(settings.notdiamond).attempt_patch()
vertexai_patcher.attempt_patch()
get_vertexai_patcher(settings.vertexai).attempt_patch()


def reset_autopatch() -> None:
Expand All @@ -102,7 +102,7 @@ def reset_autopatch() -> None:
from weave.integrations.mistral import get_mistral_patcher
from weave.integrations.notdiamond.tracing import get_notdiamond_patcher
from weave.integrations.openai.openai_sdk import get_openai_patcher
from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher
from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher

get_openai_patcher().undo_patch()
get_mistral_patcher().undo_patch()
Expand All @@ -117,4 +117,4 @@ def reset_autopatch() -> None:
get_cohere_patcher().undo_patch()
get_google_genai_patcher().undo_patch()
get_notdiamond_patcher().undo_patch()
vertexai_patcher.undo_patch()
get_vertexai_patcher().undo_patch()

0 comments on commit 3571f93

Please sign in to comment.