Skip to content

Commit

Permalink
update: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
soumik12345 committed Aug 16, 2024
1 parent fa67c4f commit b80b093
Show file tree
Hide file tree
Showing 6 changed files with 2,744 additions and 316 deletions.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

9 changes: 0 additions & 9 deletions weave/integrations/google_ai_studio/google_ai_studio_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,5 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
name="google.generativeai.GenerativeModel.generate_content_async"
),
),
SymbolPatcher(
lambda: importlib.import_module(
"google.ai.generativelanguage_v1beta.services.generative_service.client"
),
"GenerativeServiceClient.generate_content",
gemini_wrapper_sync(
name="google.ai.generativelanguage_v1beta.services.generative_service.client.GenerativeServiceClient.generate_content"
),
),
]
)
117 changes: 42 additions & 75 deletions weave/integrations/google_ai_studio/google_ai_studio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import weave
from weave.weave_client import WeaveClient
from weave.trace_server import trace_server_interface as tsi


Expand Down Expand Up @@ -43,36 +44,26 @@ def op_name_from_ref(ref: str) -> str:
filter_headers=["authorization", "x-api-key"],
allowed_hosts=["api.wandb.ai", "localhost", "trace.wandb.ai"],
)
def test_content_generation(client: weave.weave_client.WeaveClient) -> None:
def test_content_generation(client: WeaveClient) -> None:
import google.generativeai as genai

genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", "DUMMY_API_KEY"))
model = genai.GenerativeModel("gemini-1.5-flash")
response = model.generate_content("Write a story about an AI and magic")

weave_server_respose = client.server.calls_query(
weave_server_response = client.server.calls_query(
tsi.CallsQueryReq(project_id=client._project_id())
)
assert len(weave_server_respose.calls) == 4
assert len(weave_server_response.calls) == 1

flatened_calls_list = [
flattened_calls_list = [
(op_name_from_ref(c.op_name), d)
for (c, d) in flatten_calls(weave_server_respose.calls)
for (c, d) in flatten_calls(weave_server_response.calls)
]
assert flatened_calls_list == [
("google.generativeai.GenerativeModel.start_chat", 0),
("google.generativeai.GenerativeModel.generate_content", 1),
(
"google.ai.generativelanguage_v1beta.services.generative_service.client.GenerativeServiceClient.generate_content",
2,
),
(
"google.generativeai.types.generation_types.GenerateContentResponse.from_response",
2,
),
assert flattened_calls_list == [
("google.generativeai.GenerativeModel.generate_content", 0),
]

for call in weave_server_respose.calls:
for call in weave_server_response.calls:
assert call.exception is None and call.ended_at is not None


Expand All @@ -81,7 +72,7 @@ def test_content_generation(client: weave.weave_client.WeaveClient) -> None:
filter_headers=["authorization", "x-api-key"],
allowed_hosts=["api.wandb.ai", "localhost", "trace.wandb.ai"],
)
def test_content_generation_stream(client: weave.weave_client.WeaveClient) -> None:
def test_content_generation_stream(client: WeaveClient) -> None:
import google.generativeai as genai

genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", "DUMMY_API_KEY"))
Expand All @@ -90,27 +81,21 @@ def test_content_generation_stream(client: weave.weave_client.WeaveClient) -> No
"Write a story about an AI and magic", stream=True
)
chunks = [chunk.text for chunk in response]
assert len(chunks) > 1

weave_server_respose = client.server.calls_query(
weave_server_response = client.server.calls_query(
tsi.CallsQueryReq(project_id=client._project_id())
)
assert len(weave_server_respose.calls) >= 4
assert len(weave_server_response.calls) == 1

flatened_calls_list = [
flattened_calls_list = [
(op_name_from_ref(c.op_name), d)
for (c, d) in flatten_calls(weave_server_respose.calls)
for (c, d) in flatten_calls(weave_server_response.calls)
]
assert flattened_calls_list == [
("google.generativeai.GenerativeModel.generate_content", 0)
]
assert ("google.generativeai.GenerativeModel.start_chat", 0) in flatened_calls_list
assert (
"google.generativeai.GenerativeModel.generate_content",
1,
) in flatened_calls_list
assert (
"google.generativeai.types.generation_types.GenerateContentResponse.from_response",
2,
) in flatened_calls_list

for call in weave_server_respose.calls:
for call in weave_server_response.calls:
assert call.exception is None and call.ended_at is not None


Expand All @@ -119,39 +104,30 @@ def test_content_generation_stream(client: weave.weave_client.WeaveClient) -> No
filter_headers=["authorization", "x-api-key"],
allowed_hosts=["api.wandb.ai", "localhost", "trace.wandb.ai"],
)
def test_content_generation_async(client: weave.weave_client.WeaveClient) -> None:
def test_content_generation_async(client: WeaveClient) -> None:
import google.generativeai as genai

genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", "DUMMY_API_KEY"))
model = genai.GenerativeModel("gemini-1.5-flash")

async def async_generate():
response = await model.generate_content_async(
"Write a story about an AI and magic"
)
return response
return await model.generate_content_async("Write a story about an AI and magic")

response = asyncio.run(async_generate())
asyncio.run(async_generate())

weave_server_respose = client.server.calls_query(
weave_server_response = client.server.calls_query(
tsi.CallsQueryReq(project_id=client._project_id())
)
assert len(weave_server_respose.calls) == 2
assert len(weave_server_response.calls) == 1

flatened_calls_list = [
flattened_calls_list = [
(op_name_from_ref(c.op_name), d)
for (c, d) in flatten_calls(weave_server_respose.calls)
for (c, d) in flatten_calls(weave_server_response.calls)
]

assert flatened_calls_list == [
assert flattened_calls_list == [
("google.generativeai.GenerativeModel.generate_content_async", 0),
(
"google.generativeai.types.generation_types.GenerateContentResponse.from_response",
1,
),
]

for call in weave_server_respose.calls:
for call in weave_server_response.calls:
assert call.exception is None and call.ended_at is not None


Expand All @@ -160,42 +136,33 @@ async def async_generate():
filter_headers=["authorization", "x-api-key"],
allowed_hosts=["api.wandb.ai", "localhost", "trace.wandb.ai"],
)
def test_content_generation_async_stream(
client: weave.weave_client.WeaveClient,
) -> None:
def test_content_generation_async_stream(client: WeaveClient) -> None:
import google.generativeai as genai

genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", "DUMMY_API_KEY"))
model = genai.GenerativeModel("gemini-1.5-flash")

async def get_response():
chunks = []
async for chunk in await model.generate_content_async(
"Write a cute story about cats.", stream=True
"Write a story about an AI and magic", stream=True
):
chunks.append(chunk)
return chunks
if chunk.text:
print(chunk.text)
print("_" * 80)

chunks = asyncio.run(get_response())
asyncio.run(get_response())

weave_server_respose = client.server.calls_query(
weave_server_response = client.server.calls_query(
tsi.CallsQueryReq(project_id=client._project_id())
)
assert len(weave_server_respose.calls) >= 2
assert len(weave_server_response.calls) == 1

flatened_calls_list = [
flattened_calls_list = [
(op_name_from_ref(c.op_name), d)
for (c, d) in flatten_calls(weave_server_respose.calls)
for (c, d) in flatten_calls(weave_server_response.calls)
]

assert (
"google.generativeai.GenerativeModel.generate_content_async",
0,
) in flatened_calls_list
assert (
"google.generativeai.types.generation_types.GenerateContentResponse.from_response",
1,
) in flatened_calls_list

for call in weave_server_respose.calls:
assert flattened_calls_list == [
("google.generativeai.GenerativeModel.generate_content_async", 0)
]
for call in weave_server_response.calls:
assert call.exception is None and call.ended_at is not None

0 comments on commit b80b093

Please sign in to comment.