diff --git a/weave/integrations/gemini/gemini_sdk.py b/weave/integrations/gemini/gemini_sdk.py index a3e5466224f..8a0a66db793 100644 --- a/weave/integrations/gemini/gemini_sdk.py +++ b/weave/integrations/gemini/gemini_sdk.py @@ -37,12 +37,7 @@ def gemini_wrapper(name: str) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: op = weave.op()(fn) op.name = name # type: ignore - # return op - return add_accumulator( - op, # type: ignore - make_accumulator=lambda inputs: gemini_accumulator, - should_accumulate=should_use_accumulator, - ) + return op return wrapper @@ -54,5 +49,32 @@ def wrapper(fn: Callable) -> Callable: "GenerativeModel.generate_content", gemini_wrapper(name="google.generativeai.GenerativeModel.generate_content"), ), + SymbolPatcher( + lambda: importlib.import_module( + "google.generativeai.types.generation_types" + ), + "GenerateContentResponse.from_response", + gemini_wrapper( + name="google.generativeai.types.generation_types.GenerateContentResponse.from_response" + ), + ), + SymbolPatcher( + lambda: importlib.import_module( + "google.generativeai.types.generation_types" + ), + "GenerateContentResponse.from_iterator", + gemini_wrapper( + name="google.generativeai.types.generation_types.GenerateContentResponse.from_iterator" + ), + ), + SymbolPatcher( + lambda: importlib.import_module( + "google.ai.generativelanguage_v1beta.services.generative_service.client" + ), + "GenerativeServiceClient.generate_content", + gemini_wrapper( + name="google.ai.generativelanguage_v1beta.services.generative_service.client.GenerativeServiceClient.generate_content" + ), + ), ] )