diff --git a/tests/integrations/google_ai_studio/google_ai_studio_test.py b/tests/integrations/google_ai_studio/google_ai_studio_test.py index 01a610d4117..11e947e6255 100644 --- a/tests/integrations/google_ai_studio/google_ai_studio_test.py +++ b/tests/integrations/google_ai_studio/google_ai_studio_test.py @@ -54,7 +54,7 @@ def test_content_generation(client): genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY")) model = genai.GenerativeModel("gemini-1.5-flash") - model.generate_content("Explain how AI works in simple terms") + model.generate_content("What is the capital of France?") calls = list(client.calls()) assert len(calls) == 1 @@ -64,7 +64,7 @@ def test_content_generation(client): trace_name = op_name_from_ref(call.op_name) assert trace_name == "google.generativeai.GenerativeModel.generate_content" - assert call.output is not None + assert "paris" in str(call.output).lower() # TODO: Re-enable after dictify is fixed # assert_correct_output_shape(call.output) # assert_correct_summary(call.summary, trace_name) @@ -77,9 +77,7 @@ def test_content_generation_stream(client): genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY")) model = genai.GenerativeModel("gemini-1.5-flash") - response = model.generate_content( - "Explain how AI works in simple terms", stream=True - ) + response = model.generate_content("What is the capital of France?", stream=True) chunks = [chunk.text for chunk in response] assert len(chunks) > 1 @@ -91,7 +89,7 @@ def test_content_generation_stream(client): trace_name = op_name_from_ref(call.op_name) assert trace_name == "google.generativeai.GenerativeModel.generate_content" - assert call.output is not None + assert "paris" in str(call.output).lower() # TODO: Re-enable after dictify is fixed # assert_correct_output_shape(call.output) # assert_correct_summary(call.summary, trace_name) @@ -106,7 +104,7 @@ async def test_content_generation_async(client): genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY")) model = genai.GenerativeModel("gemini-1.5-flash") - _ = await model.generate_content_async("Explain how AI works in simple terms") + _ = await model.generate_content_async("What is the capital of France?") calls = list(client.calls()) assert len(calls) == 1 @@ -115,7 +113,7 @@ async def test_content_generation_async(client): assert call.started_at < call.ended_at trace_name = op_name_from_ref(call.op_name) assert trace_name == "google.generativeai.GenerativeModel.generate_content_async" - assert call.output is not None + assert "paris" in str(call.output).lower() # TODO: Re-enable after dictify is fixed # assert_correct_output_shape(call.output) # assert_correct_summary(call.summary, trace_name) @@ -143,10 +141,10 @@ def test_send_message(client): assert call.started_at < call.ended_at trace_name = op_name_from_ref(call.op_name) assert trace_name == "google.generativeai.ChatSession.send_message" - assert call.output is not None + assert "executable_code" in str(call.output).lower() call = calls[1] assert call.started_at < call.ended_at trace_name = op_name_from_ref(call.op_name) assert trace_name == "google.generativeai.GenerativeModel.generate_content" - assert call.output is not None + assert "executable_code" in str(call.output).lower()