Skip to content

Commit

Permalink
update: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
soumik12345 committed Oct 28, 2024
1 parent 7f87d50 commit c04ca38
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions tests/integrations/google_ai_studio/google_ai_studio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_content_generation(client):

genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY"))
model = genai.GenerativeModel("gemini-1.5-flash")
model.generate_content("Explain how AI works in simple terms")
model.generate_content("What is the capital of France?")

calls = list(client.calls())
assert len(calls) == 1
Expand All @@ -64,7 +64,7 @@ def test_content_generation(client):

trace_name = op_name_from_ref(call.op_name)
assert trace_name == "google.generativeai.GenerativeModel.generate_content"
assert call.output is not None
assert "paris" in str(call.output).lower()
# TODO: Re-enable after dictify is fixed
# assert_correct_output_shape(call.output)
# assert_correct_summary(call.summary, trace_name)
Expand All @@ -77,9 +77,7 @@ def test_content_generation_stream(client):

genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY"))
model = genai.GenerativeModel("gemini-1.5-flash")
response = model.generate_content(
"Explain how AI works in simple terms", stream=True
)
response = model.generate_content("What is the capital of France?", stream=True)
chunks = [chunk.text for chunk in response]
assert len(chunks) > 1

Expand All @@ -91,7 +89,7 @@ def test_content_generation_stream(client):

trace_name = op_name_from_ref(call.op_name)
assert trace_name == "google.generativeai.GenerativeModel.generate_content"
assert call.output is not None
assert "paris" in str(call.output).lower()
# TODO: Re-enable after dictify is fixed
# assert_correct_output_shape(call.output)
# assert_correct_summary(call.summary, trace_name)
Expand All @@ -106,7 +104,7 @@ async def test_content_generation_async(client):
genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY"))
model = genai.GenerativeModel("gemini-1.5-flash")

_ = await model.generate_content_async("Explain how AI works in simple terms")
_ = await model.generate_content_async("What is the capital of France?")

calls = list(client.calls())
assert len(calls) == 1
Expand All @@ -115,7 +113,7 @@ async def test_content_generation_async(client):
assert call.started_at < call.ended_at
trace_name = op_name_from_ref(call.op_name)
assert trace_name == "google.generativeai.GenerativeModel.generate_content_async"
assert call.output is not None
assert "paris" in str(call.output).lower()
# TODO: Re-enable after dictify is fixed
# assert_correct_output_shape(call.output)
# assert_correct_summary(call.summary, trace_name)
Expand Down Expand Up @@ -143,10 +141,10 @@ def test_send_message(client):
assert call.started_at < call.ended_at
trace_name = op_name_from_ref(call.op_name)
assert trace_name == "google.generativeai.ChatSession.send_message"
assert call.output is not None
assert "executable_code" in str(call.output).lower()

call = calls[1]
assert call.started_at < call.ended_at
trace_name = op_name_from_ref(call.op_name)
assert trace_name == "google.generativeai.GenerativeModel.generate_content"
assert call.output is not None
assert "executable_code" in str(call.output).lower()

0 comments on commit c04ca38

Please sign in to comment.