From ecf543a1ee103ce198d5d103635e33f72f75f8ff Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Thu, 12 Dec 2024 16:52:11 -0800 Subject: [PATCH] initial changes --- dev_docs/BaseObjectClasses.md | 8 +- tests/trace/test_base_object_classes.py | 14 +- .../wfReactInterface/baseObjectClassQuery.ts | 8 +- .../traceServerClientTypes.ts | 2 +- weave/trace_server/base_object_class_util.py | 121 -------------- .../clickhouse_trace_server_batched.py | 16 +- weave/trace_server/object_class_util.py | 149 ++++++++++++++++++ weave/trace_server/sqlite_trace_server.py | 13 +- weave/trace_server/trace_server_interface.py | 11 +- 9 files changed, 190 insertions(+), 152 deletions(-) delete mode 100644 weave/trace_server/base_object_class_util.py create mode 100644 weave/trace_server/object_class_util.py diff --git a/dev_docs/BaseObjectClasses.md b/dev_docs/BaseObjectClasses.md index a571af49755..05215583f17 100644 --- a/dev_docs/BaseObjectClasses.md +++ b/dev_docs/BaseObjectClasses.md @@ -116,7 +116,7 @@ curl -X POST 'https://trace.wandb.ai/obj/create' \ "project_id": "user/project", "object_id": "my_config", "val": {...}, - "set_base_object_class": "MyConfig" + "object_class": "MyConfig" } }' @@ -162,7 +162,7 @@ Run `make synchronize-base-object-schemas` to ensure the frontend TypeScript typ 4. Now, each use case uses different parts: 1. `Python Writing`. Users can directly import these classes and use them as normal Pydantic models, which get published with `weave.publish`. The python client correct builds the requisite payload. 2. `Python Reading`. Users can `weave.ref().get()` and the weave python SDK will return the instance with the correct type. Note: we do some special handling such that the returned object is not a WeaveObject, but literally the exact pydantic class. - 3. `HTTP Writing`. In cases where the client/user does not want to add the special type information, users can publish base objects by setting the `set_base_object_class` setting on `POST obj/create` to the name of the class. The weave server will validate the object against the schema, update the metadata fields, and store the object. + 3. `HTTP Writing`. In cases where the client/user does not want to add the special type information, users can publish builtin objects by setting the `object_class` setting on `POST obj/create` to the name of the class. The weave server will validate the object against the schema, update the metadata fields, and store the object. 4. `HTTP Reading`. When querying for objects, the server will return the object with the correct type if the `base_object_class` metadata field is set. 5. `Frontend`. The frontend will read the zod schema from `weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBaseObjectClasses.zod.ts` and use that to provide compile time type safety when using `useBaseObjectInstances` and runtime type safety when using `useCreateBaseObjectInstance`. * Note: it is critical that all techniques produce the same digest for the same data - which is tested in the tests. This way versions are not thrashed by different clients/users. @@ -185,7 +185,7 @@ graph TD subgraph "Trace Server" subgraph "HTTP API" - R --> |validates using| HW["POST obj/create
set_base_object_class"] + R --> |validates using| HW["POST obj/create
object_class"] HW --> DB[(Weave Object Store)] HR["POST objs/query
base_object_classes"] --> |Filters base_object_class| DB end @@ -203,7 +203,7 @@ graph TD Z --> |import| UBI["useBaseObjectInstances"] Z --> |import| UCI["useCreateBaseObjectInstance"] UBI --> |Filters base_object_class| HR - UCI --> |set_base_object_class| HW + UCI --> |object_class| HW UI[React UI] --> UBI UI --> UCI end diff --git a/tests/trace/test_base_object_classes.py b/tests/trace/test_base_object_classes.py index a264941f7b0..1a7e64f9ed6 100644 --- a/tests/trace/test_base_object_classes.py +++ b/tests/trace/test_base_object_classes.py @@ -139,7 +139,7 @@ def test_interface_creation(client): "project_id": client._project_id(), "object_id": nested_obj_id, "val": nested_obj.model_dump(), - "set_base_object_class": "TestOnlyNestedBaseObject", + "object_class": "TestOnlyNestedBaseObject", } } ) @@ -164,7 +164,7 @@ def test_interface_creation(client): "project_id": client._project_id(), "object_id": top_level_obj_id, "val": top_obj.model_dump(), - "set_base_object_class": "TestOnlyExample", + "object_class": "TestOnlyExample", } } ) @@ -271,7 +271,7 @@ def test_digest_equality(client): "project_id": client._project_id(), "object_id": nested_obj_id, "val": nested_obj.model_dump(), - "set_base_object_class": "TestOnlyNestedBaseObject", + "object_class": "TestOnlyNestedBaseObject", } } ) @@ -300,7 +300,7 @@ def test_digest_equality(client): "project_id": client._project_id(), "object_id": top_level_obj_id, "val": top_obj.model_dump(), - "set_base_object_class": "TestOnlyExample", + "object_class": "TestOnlyExample", } } ) @@ -322,7 +322,7 @@ def test_schema_validation(client): "object_id": "nested_obj", # Incorrect schema, should raise! "val": {"a": 2}, - "set_base_object_class": "TestOnlyNestedBaseObject", + "object_class": "TestOnlyNestedBaseObject", } } ) @@ -340,7 +340,7 @@ def test_schema_validation(client): "_class_name": "TestOnlyNestedBaseObject", "_bases": ["BaseObject", "BaseModel"], }, - "set_base_object_class": "TestOnlyNestedBaseObject", + "object_class": "TestOnlyNestedBaseObject", } } ) @@ -359,7 +359,7 @@ def test_schema_validation(client): "_class_name": "TestOnlyNestedBaseObject", "_bases": ["BaseObject", "BaseModel"], }, - "set_base_object_class": "TestOnlyExample", + "object_class": "TestOnlyExample", } } ) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.ts index 6ceb39daa70..d1ab91e10dc 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.ts @@ -113,11 +113,11 @@ export const createBaseObjectInstance = async < req: TraceObjCreateReq ): Promise => { if ( - req.obj.set_base_object_class != null && - req.obj.set_base_object_class !== baseObjectClassName + req.obj.object_class != null && + req.obj.object_class !== baseObjectClassName ) { throw new Error( - `set_base_object_class must match baseObjectClassName: ${baseObjectClassName}` + `object_class must match baseObjectClassName: ${baseObjectClassName}` ); } @@ -138,7 +138,7 @@ export const createBaseObjectInstance = async < ...req, obj: { ...req.obj, - set_base_object_class: baseObjectClassName, + object_class: baseObjectClassName, }, }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts index c396962f0fb..17e06cf24ef 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts @@ -243,7 +243,7 @@ export type TraceObjCreateReq = { project_id: string; object_id: string; val: T; - set_base_object_class?: string; + object_class?: string; }; }; diff --git a/weave/trace_server/base_object_class_util.py b/weave/trace_server/base_object_class_util.py deleted file mode 100644 index 1c52f766c0c..00000000000 --- a/weave/trace_server/base_object_class_util.py +++ /dev/null @@ -1,121 +0,0 @@ -from typing import Any, Optional - -from pydantic import BaseModel - -from weave.trace_server.interface.base_object_classes.base_object_registry import ( - BASE_OBJECT_REGISTRY, -) - -""" -There are two standard base object classes: BaseObject and Object - -`Object` is the base class for the more advanced object-oriented `weave.Object` use cases. -`BaseObject` is the more simple schema-based base object class. -""" -base_object_class_names = ["BaseObject", "Object"] - - -def get_base_object_class(val: Any) -> Optional[str]: - if isinstance(val, dict): - if "_bases" in val: - if isinstance(val["_bases"], list): - if len(val["_bases"]) >= 2: - if val["_bases"][-1] == "BaseModel": - if val["_bases"][-2] in base_object_class_names: - if len(val["_bases"]) > 2: - return val["_bases"][-3] - elif "_class_name" in val: - return val["_class_name"] - return None - - -def process_incoming_object_val( - val: Any, req_base_object_class: Optional[str] = None -) -> tuple[dict, Optional[str]]: - """ - This method is responsible for accepting an incoming object from the user, validating it - against the base object class, and returning the object with the base object class - set. It does not mutate the original object, but returns a new object with values set if needed. - - Specifically,: - - 1. If the object is not a dict, it is returned as is, and the base object class is set to None. - 2. There are 2 ways to specify the base object class: - a. The `req_base_object_class` argument. - * used by non-pythonic writers of weave objects - b. The `_bases` & `_class_name` attributes of the object, which is a list of base class names. - * used by pythonic weave object writers (legacy) - 3. If the object has a base object class that does not match the requested base object class, - an error is thrown. - 4. if the object contains a base object class inside the payload, then we simply validate - the object against the base object class (if a match is found in BASE_OBJECT_REGISTRY) - 5. If the object does not have a base object class and a requested base object class is - provided, we require a match in BASE_OBJECT_REGISTRY and validate the object against - the requested base object class. Finally, we set the correct feilds. - """ - if not isinstance(val, dict): - if req_base_object_class is not None: - raise ValueError( - "set_base_object_class cannot be provided for non-dict objects" - ) - return val, None - - dict_val = val.copy() - val_base_object_class = get_base_object_class(dict_val) - - if ( - val_base_object_class != None - and req_base_object_class != None - and val_base_object_class != req_base_object_class - ): - raise ValueError( - f"set_base_object_class must match base_object_class: {val_base_object_class} != {req_base_object_class}" - ) - - if val_base_object_class is not None: - # In this case, we simply validate if the match is found - if base_object_class_type := BASE_OBJECT_REGISTRY.get(val_base_object_class): - base_object_class_type.model_validate(dict_val) - elif req_base_object_class is not None: - # In this case, we require that the base object class is registered - if base_object_class_type := BASE_OBJECT_REGISTRY.get(req_base_object_class): - dict_val = dump_base_object(base_object_class_type.model_validate(dict_val)) - else: - raise ValueError(f"Unknown base object class: {req_base_object_class}") - - base_object_class = val_base_object_class or req_base_object_class - - return dict_val, base_object_class - - -# Server-side version of `pydantic_object_record` -def dump_base_object(val: BaseModel) -> dict: - cls = val.__class__ - cls_name = val.__class__.__name__ - bases = [c.__name__ for c in cls.mro()[1:-1]] - - dump = {} - # Order matters here due to the way we calculate the digest! - # This matches the client - dump["_type"] = cls_name - for k in val.model_fields: - dump[k] = _general_dump(getattr(val, k)) - # yes, this is done twice, to match the client - dump["_class_name"] = cls_name - dump["_bases"] = bases - return dump - - -def _general_dump(val: Any) -> Any: - if isinstance(val, BaseModel): - return dump_base_object(val) - elif isinstance(val, dict): - return {k: _general_dump(v) for k, v in val.items()} - elif isinstance(val, list): - return [_general_dump(v) for v in val] - elif isinstance(val, tuple): - return tuple(_general_dump(v) for v in val) - elif isinstance(val, set): - return {_general_dump(v) for v in val} - else: - return val diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index d40d7bcc2a3..63b6a368ebe 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -45,7 +45,6 @@ from weave.trace_server import refs_internal as ri from weave.trace_server import trace_server_interface as tsi from weave.trace_server.actions_worker.dispatcher import execute_batch -from weave.trace_server.base_object_class_util import process_incoming_object_val from weave.trace_server.calls_query_builder import ( CallsQuery, HardCodedFilter, @@ -81,6 +80,7 @@ from weave.trace_server.model_providers.model_providers import ( read_model_to_provider_info_map, ) +from weave.trace_server.object_class_util import process_incoming_object_val from weave.trace_server.orm import ParamBuilder, Row from weave.trace_server.secret_fetcher_context import _secret_fetcher_context from weave.trace_server.table_query_builder import ( @@ -563,19 +563,19 @@ def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes: return tsi.OpQueryRes(op_objs=objs) def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: - val, base_object_class = process_incoming_object_val( - req.obj.val, req.obj.set_base_object_class + processed_result = process_incoming_object_val( + req.obj.val, req.obj.object_class ) - - json_val = json.dumps(val) + processed_val = processed_result["val"] + json_val = json.dumps(processed_val) digest = str_digest(json_val) ch_obj = ObjCHInsertable( project_id=req.obj.project_id, object_id=req.obj.object_id, - kind=get_kind(val), - base_object_class=base_object_class, - refs=extract_refs_from_values(val), + kind=get_kind(processed_val), + base_object_class=processed_result["base_object_class"], + refs=extract_refs_from_values(processed_val), val_dump=json_val, digest=digest, ) diff --git a/weave/trace_server/object_class_util.py b/weave/trace_server/object_class_util.py new file mode 100644 index 00000000000..32758b1b84c --- /dev/null +++ b/weave/trace_server/object_class_util.py @@ -0,0 +1,149 @@ +from typing import Any, Optional, TypedDict + +from pydantic import BaseModel + +from weave.trace_server.interface.base_object_classes.base_object_registry import ( + BASE_OBJECT_REGISTRY, +) + +""" +There are two standard base object classes: BaseObject and Object + +`Object` is the base class for the more advanced object-oriented `weave.Object` use cases. +`BaseObject` is the more simple schema-based base object class. +""" +base_object_class_names = ["BaseObject", "Object"] + + +class GetObjectClassesResult(TypedDict): + # object_class is the "leaf" class of the val assuming it is a subclass of weave.Object or weave.BaseObject + object_class: Optional[str] + # base_object_class is the first subclass of weave.Object or weave.BaseObject + base_object_class: Optional[str] + + +def get_object_classes(val: Any) -> Optional[GetObjectClassesResult]: + if isinstance(val, dict): + if "_bases" in val: + if isinstance(val["_bases"], list): + if len(val["_bases"]) >= 2: + if val["_bases"][-1] == "BaseModel": + if val["_bases"][-2] in base_object_class_names: + object_class = val["_class_name"] + base_object_class = object_class + if len(val["_bases"]) > 2: + base_object_class = val["_bases"][-3] + return GetObjectClassesResult( + object_class=object_class, + base_object_class=base_object_class, + ) + return None + + +class ProcessIncomingObjectResult(TypedDict): + val: Any + base_object_class: Optional[str] + + +def process_incoming_object_val( + val: Any, req_object_class: Optional[str] = None +) -> ProcessIncomingObjectResult: + """ + This method is responsible for accepting an incoming object from the user and validating it + against the object class. It adds the _class_name and _bases keys correctly and returns the object + with the base object class set. It does not mutate the original object, but returns a new object + with values set if needed. + """ + # First, we ensure the object is a dict before processing it. + # If the object is not a dict, we return it as is and set the base_object_class to None. + if not isinstance(val, dict): + if req_object_class is not None: + raise ValueError("object_class cannot be provided for non-dict objects") + return ProcessIncomingObjectResult(val=val, base_object_class=None) + + # Next we extract the object classes from the object. the `_bases` and `_class_name` keys are + # special weave-added keys that tell us the class hierarchy of the object. + # _class_name is the name of the class of the object. + # _bases is a list of the class's superclasses. + # In the specific case that bases starts with `BaseModel` (Pydantic Root) and the next subclass + # is in `base_object_class_names`, then we assume it is a special Weave Object class. From there, + # we can extract the object class and base object class. + val_object_classes = get_object_classes(val) + + # In the event that we successfully extracted the object classes, we need to check if the + # requested object class matches the object class of the object. If it does not, we raise an error. + if val_object_classes: + if req_object_class: + if val_object_classes["object_class"] != req_object_class: + raise ValueError( + f"object_class must match val's defined object class: {val_object_classes['object_class']} != {req_object_class}" + ) + else: + # Note: instead of passing here, it is reasonable to conclude that we should instead raise an error - + # effectively disallowing the user from providing both a requested object class and a class hierarchy inside the payload. + pass + # In this case, we assume that the object is valid and do not need to process it. + # This would happen in practice if the user is editing an existing object by simply modifying the keys. + return ProcessIncomingObjectResult( + val=val, base_object_class=val_object_classes["base_object_class"] + ) + + # Next, we check if the user provided an object class. If they did, we need to validate the object + # and set the correct bases information. This is an important case: the user is asking us to ensure that they payload is valid and + # stored correctly. We need to validate the payload and write the correct bases information. + if req_object_class is not None: + if object_class_type := BASE_OBJECT_REGISTRY.get(req_object_class): + # TODO: in the next iteration of this code path, this is where we need to actually publish the object + # using the weave publish API instead of just dumping it. + dict_val = dump_object(object_class_type.model_validate(val)) + new_val_object_classes = get_object_classes(dict_val) + if not new_val_object_classes: + raise ValueError( + f"Unexpected error: could not get object classes for {dict_val}" + ) + if new_val_object_classes["object_class"] != req_object_class: + raise ValueError( + f"Unexpected error: base object class does not match requested object class: {new_val_object_classes['object_class']} != {req_object_class}" + ) + return ProcessIncomingObjectResult( + val=dict_val, + base_object_class=new_val_object_classes["base_object_class"], + ) + else: + raise ValueError(f"Unknown object class: {req_object_class}") + + # Finally, if there is no requested object class, just return the object as is. + return ProcessIncomingObjectResult(val=val, base_object_class=None) + + +# Server-side version of `pydantic_object_record` +def dump_object(val: BaseModel) -> dict: + cls = val.__class__ + cls_name = val.__class__.__name__ + bases = [c.__name__ for c in cls.mro()[1:-1]] + + dump = {} + # Order matters here due to the way we calculate the digest! + # This matches the client + dump["_type"] = cls_name + for k in val.model_fields: + dump[k] = _general_dump(getattr(val, k)) + # yes, this is done twice, to match the client + dump["_class_name"] = cls_name + dump["_bases"] = bases + return dump + + +def _general_dump(val: Any) -> Any: + if isinstance(val, BaseModel): + return dump_object(val) + elif isinstance(val, dict): + return {k: _general_dump(v) for k, v in val.items()} + elif isinstance(val, list): + return [_general_dump(v) for v in val] + elif isinstance(val, tuple): + return tuple(_general_dump(v) for v in val) + elif isinstance(val, set): + return {_general_dump(v) for v in val} + else: + return val diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index e14cef8041e..13ce1e51ba3 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -14,7 +14,6 @@ from weave.trace_server import refs_internal as ri from weave.trace_server import trace_server_interface as tsi -from weave.trace_server.base_object_class_util import process_incoming_object_val from weave.trace_server.emoji_util import detone_emojis from weave.trace_server.errors import InvalidRequest from weave.trace_server.feedback import ( @@ -24,6 +23,7 @@ ) from weave.trace_server.ids import generate_id from weave.trace_server.interface import query as tsi_query +from weave.trace_server.object_class_util import process_incoming_object_val from weave.trace_server.orm import Row, quote_json_path from weave.trace_server.trace_server_common import ( digest_is_version_like, @@ -611,10 +611,11 @@ def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes: def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: conn, cursor = get_conn_cursor(self.db_path) - val, base_object_class = process_incoming_object_val( - req.obj.val, req.obj.set_base_object_class + processed_result = process_incoming_object_val( + req.obj.val, req.obj.object_class ) - json_val = json.dumps(val) + processed_val = processed_result["val"] + json_val = json.dumps(processed_val) digest = str_digest(json_val) # Validate @@ -652,8 +653,8 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: req.obj.project_id, req.obj.object_id, datetime.datetime.now().isoformat(), - get_kind(val), - base_object_class, + get_kind(processed_val), + processed_result["base_object_class"], json.dumps([]), json_val, digest, diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 77ef6198ecf..2f80fc64a61 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -191,7 +191,16 @@ class ObjSchemaForInsert(BaseModel): project_id: str object_id: str val: Any - set_base_object_class: Optional[str] = None + object_class: Optional[str] = None + # Keeping `set_base_object_class` here until it is successfully removed from UI client + set_base_object_class: Optional[str] = Field( + include=False, default=None, deprecated=True + ) + + def model_post_init(self, __context: Any) -> None: + # If set_base_object_class is provided, use it to set object_class for backwards compatibility + if self.set_base_object_class is not None and self.object_class is None: + self.object_class = self.set_base_object_class class TableSchemaForInsert(BaseModel):