Skip to content

Commit

Permalink
Merge branch 'master' into add_more_scorers
Browse files Browse the repository at this point in the history
  • Loading branch information
morganmcg1 authored Oct 16, 2024
2 parents b7058e1 + eb08ded commit 7c1f50a
Show file tree
Hide file tree
Showing 20 changed files with 915 additions and 339 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ jobs:
- name: Build dist
run: |
uv build
- name: Make docs
run: |
make docs
- name: upload test distribution
uses: pypa/gh-action-pypi-publish@release/v1
if: ${{ inputs.is_test }}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ module = "weave_query.*"
ignore_errors = true

[tool.bumpversion]
current_version = "0.51.15-dev0"
current_version = "0.51.17-dev0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.
Expand Down
41 changes: 28 additions & 13 deletions tests/integrations/dspy/dspy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,22 +106,37 @@ def test_dspy_inline_signatures(client: WeaveClient) -> None:
call_1, _ = flattened_calls[0]
assert call_1.exception is None and call_1.ended_at is not None
output_1 = call_1.output
assert (
output_1
== """Prediction(
sentiment='Positive'
)"""
)

assert output_1 == {
"__class__": {
"module": "dspy.primitives.prediction",
"qualname": "Prediction",
"name": "Prediction",
},
"completions": {
"__class__": {
"module": "dspy.primitives.prediction",
"qualname": "Completions",
"name": "Completions",
},
},
}
call_2, _ = flattened_calls[1]
assert call_2.exception is None and call_2.ended_at is not None
output_2 = call_2.output
assert (
output_2
== """Prediction(
sentiment='Positive'
)"""
)
assert output_2 == {
"__class__": {
"module": "dspy.primitives.prediction",
"qualname": "Prediction",
"name": "Prediction",
},
"completions": {
"__class__": {
"module": "dspy.primitives.prediction",
"qualname": "Completions",
"name": "Completions",
},
},
}

call_3, _ = flattened_calls[2]
assert call_3.exception is None and call_3.ended_at is not None
Expand Down
86 changes: 65 additions & 21 deletions tests/integrations/google_ai_studio/google_ai_studio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,54 @@
from weave.integrations.integration_utilities import op_name_from_ref


# NOTE: These asserts are slightly more relaxed than other integrations because we can't yet save
# the output with vcrpy. When VCR.py supports GRPC, we should add recordings for these tests!
# NOTE: We have retries because these tests are not deterministic (they use the live Gemini APIs),
# which can sometimes fail unexpectedly.
def assert_correct_output_shape(output: dict) -> None:
    """Assert that a recorded Gemini call output has the expected structure.

    Only the shape (key presence and value types) is checked, not the
    generated text itself.

    Args:
        output: The call output dict. It must contain a ``candidates`` list
            and a ``usage_metadata`` dict.

    Raises:
        AssertionError: If a value has an unexpected type or
            ``finish_reason`` is not 0 or 1.
        KeyError: If a key other than ``candidates`` is missing.
    """
    assert "candidates" in output
    assert isinstance(output["candidates"], list)
    for candidate in output["candidates"]:
        # Bind outside the assert: a name bound by a walrus inside an
        # ``assert`` disappears when assertions are stripped (python -O).
        parts = candidate["content"]["parts"]
        assert isinstance(parts, list)
        for part in parts:
            assert isinstance(part["text"], str)

        # https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.preview.generative_models.FinishReason
        # 0 is FINISH_REASON_UNSPECIFIED
        # 1 is STOP
        assert candidate["finish_reason"] in (0, 1)
        assert isinstance(candidate["index"], int)
        assert isinstance(candidate["safety_ratings"], list)
        assert isinstance(candidate["token_count"], int)
        assert isinstance(candidate["grounding_attributions"], list)
        assert isinstance(candidate["avg_logprobs"], float)

    # Usage metadata belongs to the whole response, so check it once.
    usage_metadata = output["usage_metadata"]
    assert isinstance(usage_metadata, dict)
    assert isinstance(usage_metadata["prompt_token_count"], int)
    assert isinstance(usage_metadata["candidates_token_count"], int)
    assert isinstance(usage_metadata["total_token_count"], int)
    assert isinstance(usage_metadata["cached_content_token_count"], int)


def assert_correct_summary(
    summary: dict, trace_name: str, model_name: str = "gemini-1.5-flash"
) -> None:
    """Assert that a call summary records usage and weave metadata correctly.

    Args:
        summary: The call summary dict, with ``usage`` and ``weave`` keys.
        trace_name: The op name expected under ``summary["weave"]["trace_name"]``.
        model_name: The model whose usage entry is checked. Defaults to the
            model used by the tests in this file.

    Raises:
        AssertionError: If the usage entry is missing, its request count is
            not 1, any token count is not positive, or the weave status,
            trace name or latency is not as expected.
    """
    assert "usage" in summary
    assert model_name in summary["usage"]
    # Look the entry up once instead of repeating the nested lookup.
    model_usage = summary["usage"][model_name]
    assert model_usage["requests"] == 1
    assert model_usage["prompt_tokens"] > 0
    assert model_usage["completion_tokens"] > 0
    assert model_usage["total_tokens"] > 0

    assert "weave" in summary
    weave_summary = summary["weave"]
    assert weave_summary["status"] == "success"
    assert weave_summary["trace_name"] == trace_name
    assert weave_summary["latency_ms"] > 0


@pytest.mark.retry(max_attempts=5)
@pytest.mark.skip_clickhouse_client
def test_content_generation(client):
import google.generativeai as genai

genai.configure(api_key=os.environ.get("GOOGLE_GENAI_KEY"))
genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY"))
model = genai.GenerativeModel("gemini-1.5-flash")
model.generate_content("Explain how AI works in simple terms")

Expand All @@ -18,19 +61,20 @@ def test_content_generation(client):

call = calls[0]
assert call.started_at < call.ended_at
assert (
op_name_from_ref(call.op_name)
== "google.generativeai.GenerativeModel.generate_content"
)
output = call.output
assert output is not None

trace_name = op_name_from_ref(call.op_name)
assert trace_name == "google.generativeai.GenerativeModel.generate_content"
assert call.output is not None
assert_correct_output_shape(call.output)
assert_correct_summary(call.summary, trace_name)


@pytest.mark.retry(max_attempts=5)
@pytest.mark.skip_clickhouse_client
def test_content_generation_stream(client):
import google.generativeai as genai

genai.configure(api_key=os.environ.get("GOOGLE_GENAI_KEY"))
genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY"))
model = genai.GenerativeModel("gemini-1.5-flash")
response = model.generate_content(
"Explain how AI works in simple terms", stream=True
Expand All @@ -43,20 +87,21 @@ def test_content_generation_stream(client):

call = calls[0]
assert call.started_at < call.ended_at
assert (
op_name_from_ref(call.op_name)
== "google.generativeai.GenerativeModel.generate_content"
)
output = call.output
assert output is not None

trace_name = op_name_from_ref(call.op_name)
assert trace_name == "google.generativeai.GenerativeModel.generate_content"
assert call.output is not None
assert_correct_output_shape(call.output)
assert_correct_summary(call.summary, trace_name)


@pytest.mark.retry(max_attempts=5)
@pytest.mark.asyncio
@pytest.mark.skip_clickhouse_client
async def test_content_generation_async(client):
import google.generativeai as genai

genai.configure(api_key=os.environ.get("GOOGLE_GENAI_KEY"))
genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY"))
model = genai.GenerativeModel("gemini-1.5-flash")

_ = await model.generate_content_async("Explain how AI works in simple terms")
Expand All @@ -66,9 +111,8 @@ async def test_content_generation_async(client):

call = calls[0]
assert call.started_at < call.ended_at
assert (
op_name_from_ref(call.op_name)
== "google.generativeai.GenerativeModel.generate_content_async"
)
output = call.output
assert output is not None
trace_name = op_name_from_ref(call.op_name)
assert trace_name == "google.generativeai.GenerativeModel.generate_content_async"
assert call.output is not None
assert_correct_output_shape(call.output)
assert_correct_summary(call.summary, trace_name)
56 changes: 41 additions & 15 deletions tests/trace/test_client_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,25 +1477,23 @@ async def predict(self, x):


def test_unknown_input_and_output_types(client):
class MyUnserializableClassA:
class MyUnknownClassA:
a_val: float

def __init__(self, a_val) -> None:
self.a_val = a_val

class MyUnserializableClassB:
class MyUnknownClassB:
b_val: float

def __init__(self, b_val) -> None:
self.b_val = b_val

@weave.op()
def op_with_unknown_types(
a: MyUnserializableClassA, b: float
) -> MyUnserializableClassB:
return MyUnserializableClassB(a.a_val + b)
def op_with_unknown_types(a: MyUnknownClassA, b: float) -> MyUnknownClassB:
return MyUnknownClassB(a.a_val + b)

a = MyUnserializableClassA(3)
a = MyUnknownClassA(3)
res = op_with_unknown_types(a, 0.14)

assert res.b_val == 3.14
Expand All @@ -1508,25 +1506,39 @@ def op_with_unknown_types(

assert len(inner_res.calls) == 1
assert inner_res.calls[0].inputs == {
"a": repr(a),
"a": {
"__class__": {
"module": "test_client_trace",
"qualname": "test_unknown_input_and_output_types.<locals>.MyUnknownClassA",
"name": "MyUnknownClassA",
},
"a_val": 3,
},
"b": 0.14,
}
assert inner_res.calls[0].output == repr(res)
assert inner_res.calls[0].output == {
"__class__": {
"module": "test_client_trace",
"qualname": "test_unknown_input_and_output_types.<locals>.MyUnknownClassB",
"name": "MyUnknownClassB",
},
"b_val": 3.14,
}


def test_unknown_attribute(client):
class MyUnserializableClass:
class MyUnknownClass:
val: int

def __init__(self, a_val) -> None:
self.a_val = a_val

class MySerializableClass(weave.Object):
obj: MyUnserializableClass
obj: MyUnknownClass

a_obj = MyUnserializableClass(1)
a_obj = MyUnknownClass(1)
a = MySerializableClass(obj=a_obj)
b_obj = MyUnserializableClass(2)
b_obj = MyUnknownClass(2)
b = MySerializableClass(obj=b_obj)

ref_a = weave.publish(a)
Expand All @@ -1535,8 +1547,22 @@ class MySerializableClass(weave.Object):
a2 = weave.ref(ref_a.uri()).get()
b2 = weave.ref(ref_b.uri()).get()

assert a2.obj == repr(a_obj)
assert b2.obj == repr(b_obj)
assert a2.obj == {
"__class__": {
"module": "test_client_trace",
"qualname": "test_unknown_attribute.<locals>.MyUnknownClass",
"name": "MyUnknownClass",
},
"a_val": 1,
}
assert b2.obj == {
"__class__": {
"module": "test_client_trace",
"qualname": "test_unknown_attribute.<locals>.MyUnknownClass",
"name": "MyUnknownClass",
},
"a_val": 2,
}


@contextmanager
Expand Down
6 changes: 3 additions & 3 deletions tests/trace/test_evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,9 +768,9 @@ def function_score(image, dc, model, obj, text, output) -> bool:
# So this assertion is checking current state, but not
# the correct behavior of the dataset (these should be the
# MyDataclass, MyModel, and MyObj)
assert isinstance(row["dc"], str) # MyDataclass
assert isinstance(row["model"], str) # MyModel
assert isinstance(row["obj"], str) # MyObj
assert isinstance(row["dc"], dict) # MyDataclass
assert isinstance(row["model"], dict) # MyModel
assert isinstance(row["obj"], dict) # MyObj
assert isinstance(row["text"], str)

access_log = client.server.attribute_access_log
Expand Down
Loading

0 comments on commit 7c1f50a

Please sign in to comment.