diff --git a/tests/integrations/google_ai_studio/google_ai_studio_test.py b/tests/integrations/google_ai_studio/google_ai_studio_test.py index 11e947e6255..dd071136d47 100644 --- a/tests/integrations/google_ai_studio/google_ai_studio_test.py +++ b/tests/integrations/google_ai_studio/google_ai_studio_test.py @@ -148,3 +148,35 @@ def test_send_message(client): trace_name = op_name_from_ref(call.op_name) assert trace_name == "google.generativeai.GenerativeModel.generate_content" assert "executable_code" in str(call.output).lower() + + +@pytest.mark.retry(max_attempts=5) +@pytest.mark.asyncio +@pytest.mark.skip_clickhouse_client +async def test_send_message_async(client): + import google.generativeai as genai + + genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY")) + model = genai.GenerativeModel(model_name="gemini-1.5-pro", tools="code_execution") + chat = model.start_chat() + await chat.send_message_async( + ( + "What is the sum of the first 50 prime numbers? " + "Generate and run code for the calculation, and make sure you get all 50." + ) + ) + + calls = list(client.calls()) + assert len(calls) == 2 + + call = calls[0] + assert call.started_at < call.ended_at + trace_name = op_name_from_ref(call.op_name) + assert trace_name == "google.generativeai.ChatSession.send_message_async" + assert "executable_code" in str(call.output).lower() + + call = calls[1] + assert call.started_at < call.ended_at + trace_name = op_name_from_ref(call.op_name) + assert trace_name == "google.generativeai.GenerativeModel.generate_content_async" + assert "executable_code" in str(call.output).lower()