Commit 435ca80

test
andrewtruong committed Dec 12, 2024
1 parent 781edd3 commit 435ca80
Showing 2 changed files with 93 additions and 46 deletions.
weave/integrations/google_ai_studio/google_ai_studio_sdk.py (129 changes: 88 additions, 41 deletions)
@@ -1,16 +1,21 @@
 from __future__ import annotations
 
 import importlib
 from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import TYPE_CHECKING, Any, Callable
 
 import weave
+from weave.trace.autopatch import IntegrationSettings, OpSettings
 from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, SymbolPatcher
+from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
 from weave.trace.serialize import dictify
 from weave.trace.weave_client import Call
 
 if TYPE_CHECKING:
     from google.generativeai.types.generation_types import GenerateContentResponse
 
+_google_genai_patcher: MultiPatcher | None = None
+
 
 def gemini_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]:
     if "self" in inputs:
@@ -19,8 +24,8 @@ def gemini_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]:
 
 
 def gemini_accumulator(
-    acc: Optional["GenerateContentResponse"], value: "GenerateContentResponse"
-) -> "GenerateContentResponse":
+    acc: GenerateContentResponse | None, value: GenerateContentResponse
+) -> GenerateContentResponse:
     if acc is None:
         return value
 
@@ -64,9 +69,7 @@ def gemini_accumulator(
     return acc
 
 
-def gemini_on_finish(
-    call: Call, output: Any, exception: Optional[BaseException]
-) -> None:
+def gemini_on_finish(call: Call, output: Any, exception: BaseException | None) -> None:
     if "model_name" in call.inputs["self"]:
         original_model_name = call.inputs["self"]["model_name"]
     elif "model" in call.inputs["self"]:
@@ -89,10 +92,13 @@ def gemini_on_finish(
     call.summary.update(summary_update)
 
 
-def gemini_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
+def gemini_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]:
     def wrapper(fn: Callable) -> Callable:
-        op = weave.op(postprocess_inputs=gemini_postprocess_inputs)(fn)
-        op.name = name  # type: ignore
+        op_kwargs = settings.model_dump()
+        if not op_kwargs.get("postprocess_inputs"):
+            op_kwargs["postprocess_inputs"] = gemini_postprocess_inputs
+
+        op = weave.op(fn, **op_kwargs)
         op._set_on_finish_handler(gemini_on_finish)
         return add_accumulator(
             op,  # type: ignore
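A quick illustration of the new settings flow in the sync wrapper: the op keyword arguments now come from OpSettings.model_dump(), and gemini_postprocess_inputs is only applied when the caller has not supplied a postprocessor of their own. The sketch below is not part of the commit; it assumes OpSettings exposes name and postprocess_inputs fields, as the hunk above implies, and strip_self is a hypothetical stand-in postprocessor.

from weave.trace.autopatch import OpSettings

# Default settings: postprocess_inputs is unset, so the wrapper falls back to
# gemini_postprocess_inputs (the "if not op_kwargs.get(...)" branch above).
op_kwargs = OpSettings().model_dump()
print(op_kwargs.get("postprocess_inputs"))  # None, fallback applies


def strip_self(inputs: dict) -> dict:
    # Hypothetical caller-supplied postprocessor, for illustration only.
    inputs.pop("self", None)
    return inputs


# Caller-supplied settings win over the integration default.
op_kwargs = OpSettings(
    name="my.gemini.generate_content",  # also overrides the default op name
    postprocess_inputs=strip_self,
).model_dump()
print(op_kwargs["postprocess_inputs"])  # strip_self, fallback skipped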
@@ -104,7 +110,7 @@ def wrapper(fn: Callable) -> Callable:
     return wrapper
 
 
-def gemini_wrapper_async(name: str) -> Callable[[Callable], Callable]:
+def gemini_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]:
     def wrapper(fn: Callable) -> Callable:
         def _fn_wrapper(fn: Callable) -> Callable:
             @wraps(fn)
@@ -113,9 +119,11 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
 
             return _async_wrapper
 
         "We need to do this so we can check if `stream` is used"
-        op = weave.op(postprocess_inputs=gemini_postprocess_inputs)(_fn_wrapper(fn))
-        op.name = name  # type: ignore
+        op_kwargs = settings.model_dump()
+        if not op_kwargs.get("postprocess_inputs"):
+            op_kwargs["postprocess_inputs"] = gemini_postprocess_inputs
+
+        op = weave.op(_fn_wrapper(fn), **op_kwargs)
         op._set_on_finish_handler(gemini_on_finish)
         return add_accumulator(
             op,  # type: ignore
@@ -127,33 +135,72 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
     return wrapper
 
 
-google_genai_patcher = MultiPatcher(
-    [
-        SymbolPatcher(
-            lambda: importlib.import_module("google.generativeai.generative_models"),
-            "GenerativeModel.generate_content",
-            gemini_wrapper_sync(
-                name="google.generativeai.GenerativeModel.generate_content"
-            ),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("google.generativeai.generative_models"),
-            "GenerativeModel.generate_content_async",
-            gemini_wrapper_async(
-                name="google.generativeai.GenerativeModel.generate_content_async"
-            ),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("google.generativeai.generative_models"),
-            "ChatSession.send_message",
-            gemini_wrapper_sync(name="google.generativeai.ChatSession.send_message"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("google.generativeai.generative_models"),
-            "ChatSession.send_message_async",
-            gemini_wrapper_async(
-                name="google.generativeai.ChatSession.send_message_async"
-            ),
-        ),
-    ]
-)
+def get_google_genai_patcher(
+    settings: IntegrationSettings | None = None,
+) -> MultiPatcher | NoOpPatcher:
+    if settings is None:
+        settings = IntegrationSettings()
+
+    if not settings.enabled:
+        return NoOpPatcher()
+
+    global _google_genai_patcher
+    if _google_genai_patcher is not None:
+        return _google_genai_patcher
+
+    base = settings.op_settings
+
+    generate_content_settings = base.model_copy(
+        update={
+            "name": base.name or "google.generativeai.GenerativeModel.generate_content"
+        }
+    )
+    generate_content_async_settings = base.model_copy(
+        update={
+            "name": base.name
+            or "google.generativeai.GenerativeModel.generate_content_async"
+        }
+    )
+    send_message_settings = base.model_copy(
+        update={"name": base.name or "google.generativeai.ChatSession.send_message"}
+    )
+    send_message_async_settings = base.model_copy(
+        update={
+            "name": base.name or "google.generativeai.ChatSession.send_message_async"
+        }
+    )
+
+    _google_genai_patcher = MultiPatcher(
+        [
+            SymbolPatcher(
+                lambda: importlib.import_module(
+                    "google.generativeai.generative_models"
+                ),
+                "GenerativeModel.generate_content",
+                gemini_wrapper_sync(generate_content_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module(
+                    "google.generativeai.generative_models"
+                ),
+                "GenerativeModel.generate_content_async",
+                gemini_wrapper_async(generate_content_async_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module(
+                    "google.generativeai.generative_models"
+                ),
+                "ChatSession.send_message",
+                gemini_wrapper_sync(send_message_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module(
+                    "google.generativeai.generative_models"
+                ),
+                "ChatSession.send_message_async",
+                gemini_wrapper_async(send_message_async_settings),
+            ),
+        ]
+    )
+
+    return _google_genai_patcher
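
Taken together, these hunks replace the module-level google_genai_patcher with a cached factory: disabled settings short-circuit to a NoOpPatcher, and the first enabled call builds the MultiPatcher and stores it in _google_genai_patcher for reuse. A minimal usage sketch, not part of the commit, using only the names introduced above and assuming IntegrationSettings defaults to enabled=True (as the reset_autopatch() wiring in weave/trace/autopatch.py below suggests):

from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.integrations.google_ai_studio.google_ai_studio_sdk import (
    get_google_genai_patcher,
)

# Disabled settings return a NoOpPatcher, so nothing is monkey-patched.
disabled = get_google_genai_patcher(IntegrationSettings(enabled=False))
disabled.attempt_patch()

# The first enabled call builds and caches the MultiPatcher; op-level overrides
# such as a custom op name flow in through OpSettings.
patcher = get_google_genai_patcher(
    IntegrationSettings(op_settings=OpSettings(name="my.gemini.op"))
)
patcher.attempt_patch()
assert get_google_genai_patcher() is patcher  # later calls hit the cache

patcher.undo_patch()
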
weave/trace/autopatch.py (10 changes: 5 additions, 5 deletions)
@@ -38,7 +38,7 @@ class AutopatchSettings(BaseModel):
     cerebras: IntegrationSettings = Field(default_factory=IntegrationSettings)
     cohere: IntegrationSettings = Field(default_factory=IntegrationSettings)
     dspy: IntegrationSettings = Field(default_factory=IntegrationSettings)
-    # google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings)
+    google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings)
     # groq: IntegrationSettings = Field(default_factory=IntegrationSettings)
     # instructor: IntegrationSettings = Field(default_factory=IntegrationSettings)
     # langchain: IntegrationSettings = Field(default_factory=IntegrationSettings)
@@ -57,7 +57,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
     from weave.integrations.cohere.cohere_sdk import get_cohere_patcher
     from weave.integrations.dspy.dspy_sdk import get_dspy_patcher
     from weave.integrations.google_ai_studio.google_ai_studio_sdk import (
-        google_genai_patcher,
+        get_google_genai_patcher,
     )
     from weave.integrations.groq.groq_sdk import groq_patcher
     from weave.integrations.instructor.instructor_sdk import instructor_patcher
@@ -83,7 +83,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
     get_dspy_patcher(settings.dspy).attempt_patch()
     get_cerebras_patcher(settings.cerebras).attempt_patch()
     get_cohere_patcher(settings.cohere).attempt_patch()
-    google_genai_patcher.attempt_patch()
+    get_google_genai_patcher(settings.google_ai_studio).attempt_patch()
     notdiamond_patcher.attempt_patch()
     vertexai_patcher.attempt_patch()

@@ -94,7 +94,7 @@ def reset_autopatch() -> None:
     from weave.integrations.cohere.cohere_sdk import get_cohere_patcher
     from weave.integrations.dspy.dspy_sdk import get_dspy_patcher
     from weave.integrations.google_ai_studio.google_ai_studio_sdk import (
-        google_genai_patcher,
+        get_google_genai_patcher,
    )
     from weave.integrations.groq.groq_sdk import groq_patcher
     from weave.integrations.instructor.instructor_sdk import instructor_patcher
@@ -117,6 +117,6 @@ def reset_autopatch() -> None:
     get_dspy_patcher().undo_patch()
     get_cerebras_patcher().undo_patch()
     get_cohere_patcher().undo_patch()
-    google_genai_patcher.undo_patch()
+    get_google_genai_patcher().undo_patch()
     notdiamond_patcher.undo_patch()
     vertexai_patcher.undo_patch()

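With google_ai_studio now un-commented in AutopatchSettings, the integration can be toggled like the others. A hedged sketch of that configuration path, reusing only names that appear in this diff (AutopatchSettings, IntegrationSettings, autopatch, reset_autopatch), all defined in weave/trace/autopatch.py:

from weave.trace.autopatch import (
    AutopatchSettings,
    IntegrationSettings,
    autopatch,
    reset_autopatch,
)

# Opt out of patching google-generativeai while leaving other integrations at
# their defaults; get_google_genai_patcher() then resolves to a NoOpPatcher,
# so its attempt_patch() call inside autopatch() does nothing.
settings = AutopatchSettings(google_ai_studio=IntegrationSettings(enabled=False))

autopatch(settings)
# ... run google.generativeai code without Weave tracing it ...
reset_autopatch()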