Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Oct 30, 2024
1 parent 298ffc5 commit f6be1ce
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 35 deletions.
151 changes: 118 additions & 33 deletions tests/trace/test_base_object_classes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
"""
This test file ensures the base_object_classes behavior is as expected. Specifically:
1. We ensure that pythonic publishing and getting of objects:
a. Results in the correct base_object_class filter in the query.
b. Produces identical results.
2. We ensure that using the low-level interface:
a. Results in the correct base_object_class filter in the query.
b. Produces identical results.
3. We ensure that digests are equivalent between pythonic and interface style creation.
This is important to ensure that UI-based generation of objects is consistent with
programmatic generation.
4. We ensure that invalid schemas are properly rejected from the server.
"""

import pytest

import weave
Expand All @@ -19,41 +33,48 @@ def test_pythonic_creation(client: WeaveClient):

top_obj_gotten = weave.ref(ref.uri()).get()

assert isinstance(top_obj_gotten, base_objects.TestOnlyExample)
assert top_obj_gotten.model_dump() == top_obj.model_dump()

objs = client.server.obj_query(
tsi.ObjQueryReq.model_validate({
"project_id": client._project_id(),
"filter": {"base_object_classes": ["TestOnlyExample"]}},
tsi.ObjQueryReq.model_validate(
{
"project_id": client._project_id(),
"filter": {"base_object_classes": ["TestOnlyExample"]},
},
)
)

assert len(objs) == 1
assert objs[0].val == top_obj.model_dump()


objs = client.server.obj_query(
tsi.ObjQueryReq.model_validate({
"project_id": client._project_id(),
"filter": {"base_object_classes": ["TestOnlyNestedBaseObject"]}},
tsi.ObjQueryReq.model_validate(
{
"project_id": client._project_id(),
"filter": {"base_object_classes": ["TestOnlyNestedBaseObject"]},
},
)
)

assert len(objs) == 1
assert objs[0].val == nested_obj.model_dump()


def test_interface_creation(client):
# Now we will do the equivant operation using low-level interface.
nested_obj_id = "nested_obj"
nested_obj = base_objects.TestOnlyNestedBaseObject(b=3)
nested_obj_res = client.server.obj_create(
tsi.ObjCreateReq.model_validate({
"obj": {
"project_id": client._project_id(),
"object_id": nested_obj_id,
"val": nested_obj.model_dump(),
tsi.ObjCreateReq.model_validate(
{
"obj": {
"project_id": client._project_id(),
"object_id": nested_obj_id,
"val": nested_obj.model_dump(),
}
}
})
)
)
nested_obj_ref = ObjectRef(
entity=client.entity,
Expand All @@ -69,13 +90,15 @@ def test_interface_creation(client):
nested_obj=nested_obj_ref.uri(),
)
top_obj_res = client.server.obj_create(
tsi.ObjCreateReq.model_validate({
"obj": {
"project_id": client._project_id(),
"object_id": top_level_obj_id,
"val": top_obj.model_dump(),
tsi.ObjCreateReq.model_validate(
{
"obj": {
"project_id": client._project_id(),
"object_id": top_level_obj_id,
"val": top_obj.model_dump(),
}
}
})
)
)
top_obj_ref = ObjectRef(
entity=client.entity,
Expand All @@ -93,35 +116,97 @@ def test_interface_creation(client):
assert nested_obj_gotten.model_dump() == nested_obj.model_dump()

objs = client.server.obj_query(
tsi.ObjQueryReq.model_validate({
"project_id": client._project_id(),
"filter": {"base_object_classes": ["TestOnlyExample"]}},
tsi.ObjQueryReq.model_validate(
{
"project_id": client._project_id(),
"filter": {"base_object_classes": ["TestOnlyExample"]},
},
)
)

assert len(objs) == 1
assert objs[0].val == top_obj.model_dump()


objs = client.server.obj_query(
tsi.ObjQueryReq.model_validate({
"project_id": client._project_id(),
"filter": {"base_object_classes": ["TestOnlyNestedBaseObject"]}},
tsi.ObjQueryReq.model_validate(
{
"project_id": client._project_id(),
"filter": {"base_object_classes": ["TestOnlyNestedBaseObject"]},
},
)
)

assert len(objs) == 1
assert objs[0].val == nested_obj.model_dump()


def test_digest_equality(client):
# Next, let's make sure that the digests are all equivalent

nested_obj = base_objects.TestOnlyNestedBaseObject(b=3)
top_obj = base_objects.TestOnlyExample(
primitive=1,
nested_base_model=base_objects.TestOnlyNestedBaseModel(a=2),
nested_obj=weave.publish(nested_obj).uri(),
)
ref = weave.publish(top_obj)
pythonic_digest = ref.digest

# Now we will do the equivant operation using low-level interface.
nested_obj_id = "nested_obj"
nested_obj = base_objects.TestOnlyNestedBaseObject(b=3)
nested_obj_res = client.server.obj_create(
tsi.ObjCreateReq.model_validate(
{
"obj": {
"project_id": client._project_id(),
"object_id": nested_obj_id,
"val": nested_obj.model_dump(),
}
}
)
)
nested_obj_ref = ObjectRef(
entity=client.entity,
project=client.project,
name=nested_obj_id,
digest=nested_obj_res.digest,
)

top_level_obj_id = "top_obj"
top_obj = base_objects.TestOnlyExample(
primitive=1,
nested_base_model=base_objects.TestOnlyNestedBaseModel(a=2),
nested_obj=nested_obj_ref.uri(),
)
top_obj_res = client.server.obj_create(
tsi.ObjCreateReq.model_validate(
{
"obj": {
"project_id": client._project_id(),
"object_id": top_level_obj_id,
"val": top_obj.model_dump(),
}
}
)
)

interface_style_digest = top_obj_res.digest

assert pythonic_digest == interface_style_digest


def test_schema_validation(client):
# Test that we can't create an object with the wrong schema
with pytest.raises(weave.errors.WeaveError):
with pytest.raises():
client.server.obj_create(
tsi.ObjCreateReq.model_validate({
"obj": {
"project_id": client._project_id(),
"object_id": "nested_obj",
"val": {"a": 2},
tsi.ObjCreateReq.model_validate(
{
"obj": {
"project_id": client._project_id(),
"object_id": "nested_obj",
"val": {"a": 2},
}
}
})
)
)
2 changes: 1 addition & 1 deletion weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,7 +1232,7 @@ def _save_nested_objects(self, obj: Any, name: Optional[str] = None) -> Any:
# Case 1: Object:
# Here we recurse into each of the properties of the object
# and save them, and then save the object itself.
if isinstance(obj, Object): # TODO: add the generic object here
if isinstance(obj, Object): # TODO: add the generic object here
obj_rec = pydantic_object_record(obj)
for v in obj_rec.__dict__.values():
self._save_nested_objects(v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

RefStr = str


class BaseObject(pydantic.BaseModel):
name: Optional[str] = None
description: Optional[str] = None
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ class TestOnlyExample(base_object_def.BaseObject):
nested_base_object: base_object_def.RefStr


__all__ = ["TestOnlyExample", "TestOnlyNestedBaseObject", "TestOnlyNestedBaseModel"]
__all__ = ["TestOnlyExample", "TestOnlyNestedBaseObject", "TestOnlyNestedBaseModel"]

0 comments on commit f6be1ce

Please sign in to comment.