diff --git a/tests/trace/test_base_object_classes.py b/tests/trace/test_base_object_classes.py index 5d6df3600ea..c0c9ba25b14 100644 --- a/tests/trace/test_base_object_classes.py +++ b/tests/trace/test_base_object_classes.py @@ -1,3 +1,17 @@ +""" +This test file ensures the base_object_classes behavior is as expected. Specifically: +1. We ensure that pythonic publishing and getting of objects: + a. Results in the correct base_object_class filter in the query. + b. Produces identical results. +2. We ensure that using the low-level interface: + a. Results in the correct base_object_class filter in the query. + b. Produces identical results. +3. We ensure that digests are equivalent between pythonic and interface style creation. + This is important to ensure that UI-based generation of objects is consistent with + programmatic generation. +4. We ensure that invalid schemas are properly rejected from the server. +""" + import pytest import weave @@ -19,41 +33,48 @@ def test_pythonic_creation(client: WeaveClient): top_obj_gotten = weave.ref(ref.uri()).get() + assert isinstance(top_obj_gotten, base_objects.TestOnlyExample) assert top_obj_gotten.model_dump() == top_obj.model_dump() objs = client.server.obj_query( - tsi.ObjQueryReq.model_validate({ - "project_id": client._project_id(), - "filter": {"base_object_classes": ["TestOnlyExample"]}}, + tsi.ObjQueryReq.model_validate( + { + "project_id": client._project_id(), + "filter": {"base_object_classes": ["TestOnlyExample"]}, + }, ) ) assert len(objs) == 1 assert objs[0].val == top_obj.model_dump() - objs = client.server.obj_query( - tsi.ObjQueryReq.model_validate({ - "project_id": client._project_id(), - "filter": {"base_object_classes": ["TestOnlyNestedBaseObject"]}}, + tsi.ObjQueryReq.model_validate( + { + "project_id": client._project_id(), + "filter": {"base_object_classes": ["TestOnlyNestedBaseObject"]}, + }, ) ) assert len(objs) == 1 assert objs[0].val == nested_obj.model_dump() + def test_interface_creation(client): # Now we will do the equivant operation using low-level interface. nested_obj_id = "nested_obj" nested_obj = base_objects.TestOnlyNestedBaseObject(b=3) nested_obj_res = client.server.obj_create( - tsi.ObjCreateReq.model_validate({ - "obj": { - "project_id": client._project_id(), - "object_id": nested_obj_id, - "val": nested_obj.model_dump(), + tsi.ObjCreateReq.model_validate( + { + "obj": { + "project_id": client._project_id(), + "object_id": nested_obj_id, + "val": nested_obj.model_dump(), + } } - }) + ) ) nested_obj_ref = ObjectRef( entity=client.entity, @@ -69,13 +90,15 @@ def test_interface_creation(client): nested_obj=nested_obj_ref.uri(), ) top_obj_res = client.server.obj_create( - tsi.ObjCreateReq.model_validate({ - "obj": { - "project_id": client._project_id(), - "object_id": top_level_obj_id, - "val": top_obj.model_dump(), + tsi.ObjCreateReq.model_validate( + { + "obj": { + "project_id": client._project_id(), + "object_id": top_level_obj_id, + "val": top_obj.model_dump(), + } } - }) + ) ) top_obj_ref = ObjectRef( entity=client.entity, @@ -93,35 +116,97 @@ def test_interface_creation(client): assert nested_obj_gotten.model_dump() == nested_obj.model_dump() objs = client.server.obj_query( - tsi.ObjQueryReq.model_validate({ - "project_id": client._project_id(), - "filter": {"base_object_classes": ["TestOnlyExample"]}}, + tsi.ObjQueryReq.model_validate( + { + "project_id": client._project_id(), + "filter": {"base_object_classes": ["TestOnlyExample"]}, + }, ) ) assert len(objs) == 1 assert objs[0].val == top_obj.model_dump() - objs = client.server.obj_query( - tsi.ObjQueryReq.model_validate({ - "project_id": client._project_id(), - "filter": {"base_object_classes": ["TestOnlyNestedBaseObject"]}}, + tsi.ObjQueryReq.model_validate( + { + "project_id": client._project_id(), + "filter": {"base_object_classes": ["TestOnlyNestedBaseObject"]}, + }, ) ) assert len(objs) == 1 assert objs[0].val == nested_obj.model_dump() + +def test_digest_equality(client): + # Next, let's make sure that the digests are all equivalent + + nested_obj = base_objects.TestOnlyNestedBaseObject(b=3) + top_obj = base_objects.TestOnlyExample( + primitive=1, + nested_base_model=base_objects.TestOnlyNestedBaseModel(a=2), + nested_obj=weave.publish(nested_obj).uri(), + ) + ref = weave.publish(top_obj) + pythonic_digest = ref.digest + + # Now we will do the equivant operation using low-level interface. + nested_obj_id = "nested_obj" + nested_obj = base_objects.TestOnlyNestedBaseObject(b=3) + nested_obj_res = client.server.obj_create( + tsi.ObjCreateReq.model_validate( + { + "obj": { + "project_id": client._project_id(), + "object_id": nested_obj_id, + "val": nested_obj.model_dump(), + } + } + ) + ) + nested_obj_ref = ObjectRef( + entity=client.entity, + project=client.project, + name=nested_obj_id, + digest=nested_obj_res.digest, + ) + + top_level_obj_id = "top_obj" + top_obj = base_objects.TestOnlyExample( + primitive=1, + nested_base_model=base_objects.TestOnlyNestedBaseModel(a=2), + nested_obj=nested_obj_ref.uri(), + ) + top_obj_res = client.server.obj_create( + tsi.ObjCreateReq.model_validate( + { + "obj": { + "project_id": client._project_id(), + "object_id": top_level_obj_id, + "val": top_obj.model_dump(), + } + } + ) + ) + + interface_style_digest = top_obj_res.digest + + assert pythonic_digest == interface_style_digest + + def test_schema_validation(client): # Test that we can't create an object with the wrong schema - with pytest.raises(weave.errors.WeaveError): + with pytest.raises(): client.server.obj_create( - tsi.ObjCreateReq.model_validate({ - "obj": { - "project_id": client._project_id(), - "object_id": "nested_obj", - "val": {"a": 2}, + tsi.ObjCreateReq.model_validate( + { + "obj": { + "project_id": client._project_id(), + "object_id": "nested_obj", + "val": {"a": 2}, + } } - }) + ) ) diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 57166f92fd1..656bf1c360f 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -1232,7 +1232,7 @@ def _save_nested_objects(self, obj: Any, name: Optional[str] = None) -> Any: # Case 1: Object: # Here we recurse into each of the properties of the object # and save them, and then save the object itself. - if isinstance(obj, Object): # TODO: add the generic object here + if isinstance(obj, Object): # TODO: add the generic object here obj_rec = pydantic_object_record(obj) for v in obj_rec.__dict__.values(): self._save_nested_objects(v) diff --git a/weave/trace_server/interface/base_object_classes/base_object_def.py b/weave/trace_server/interface/base_object_classes/base_object_def.py index b9ab4fa7a8d..746aae1df5f 100644 --- a/weave/trace_server/interface/base_object_classes/base_object_def.py +++ b/weave/trace_server/interface/base_object_classes/base_object_def.py @@ -4,6 +4,7 @@ RefStr = str + class BaseObject(pydantic.BaseModel): name: Optional[str] = None description: Optional[str] = None diff --git a/weave/trace_server/interface/base_object_classes/test_only_example.py b/weave/trace_server/interface/base_object_classes/test_only_example.py index dd355110bde..b93af28790e 100644 --- a/weave/trace_server/interface/base_object_classes/test_only_example.py +++ b/weave/trace_server/interface/base_object_classes/test_only_example.py @@ -20,4 +20,4 @@ class TestOnlyExample(base_object_def.BaseObject): nested_base_object: base_object_def.RefStr -__all__ = ["TestOnlyExample", "TestOnlyNestedBaseObject", "TestOnlyNestedBaseModel"] \ No newline at end of file +__all__ = ["TestOnlyExample", "TestOnlyNestedBaseObject", "TestOnlyNestedBaseModel"]