diff --git a/weave/integrations/gemini/gemini_sdk.py b/weave/integrations/gemini/gemini_sdk.py index 8a0a66db793..a3523f50d18 100644 --- a/weave/integrations/gemini/gemini_sdk.py +++ b/weave/integrations/gemini/gemini_sdk.py @@ -50,21 +50,24 @@ def wrapper(fn: Callable) -> Callable: gemini_wrapper(name="google.generativeai.GenerativeModel.generate_content"), ), SymbolPatcher( - lambda: importlib.import_module( - "google.generativeai.types.generation_types" - ), - "GenerateContentResponse.from_response", + lambda: importlib.import_module("google.generativeai"), + "GenerativeModel.generate_content_async", gemini_wrapper( - name="google.generativeai.types.generation_types.GenerateContentResponse.from_response" + name="google.generativeai.GenerativeModel.generate_content_async" ), ), + SymbolPatcher( + lambda: importlib.import_module("google.generativeai"), + "GenerativeModel.generate_content", + gemini_wrapper(name="google.generativeai.GenerativeModel.start_chat"), + ), SymbolPatcher( lambda: importlib.import_module( "google.generativeai.types.generation_types" ), - "GenerateContentResponse.from_iterator", + "GenerateContentResponse.from_response", gemini_wrapper( - name="google.generativeai.types.generation_types.GenerateContentResponse.from_iterator" + name="google.generativeai.types.generation_types.GenerateContentResponse.from_response" ), ), SymbolPatcher(