diff --git a/tests/trace/builtin_objects/test_builtin_model.py b/tests/trace/builtin_objects/test_builtin_model.py index efb2e3c3456..17e2d977f02 100644 --- a/tests/trace/builtin_objects/test_builtin_model.py +++ b/tests/trace/builtin_objects/test_builtin_model.py @@ -4,11 +4,27 @@ from weave.trace.weave_client import WeaveClient from weave.trace_server import trace_server_interface as tsi +model_args = dict( + model="gpt-4o", + messages_template=[{"role": "user", "content": "{input}"}], + response_format={ + "type": "json_schema", + "json_schema": { + "name": "Person", + "schema": { + "type": "object", + "properties": { + "age": {"type": "integer"}, + "name": {"type": "string"}, + }, + }, + }, + }, +) + def test_publishing_alignment(client: WeaveClient): - model = LiteLLMCompletionModel( - model="gpt-4o", messages_template=[{"role": "user", "content": "{input}"}] - ) + model = LiteLLMCompletionModel(**model_args) publish_ref = weave.publish(model) obj_create_res = client.server.obj_create( @@ -17,10 +33,7 @@ def test_publishing_alignment(client: WeaveClient): "obj": { "project_id": client._project_id(), "object_id": "LiteLLMCompletionModel", - "val": { - "model": "gpt-4o", - "messages_template": [{"role": "user", "content": "{input}"}], - }, + "val": model_args, "set_leaf_object_class": "LiteLLMCompletionModel", } } @@ -31,31 +44,13 @@ def test_publishing_alignment(client: WeaveClient): def test_local_create_local_use(client: WeaveClient): - model = LiteLLMCompletionModel( - model="gpt-4o", - messages_template=[{"role": "user", "content": "{input}"}], - response_format={ - "type": "json_schema", - "json_schema": { - "name": "Person", - "schema": { - "type": "object", - "properties": { - "age": {"type": "integer"}, - "name": {"type": "string"}, - }, - }, - }, - }, - ) + model = LiteLLMCompletionModel(**model_args) predict_result = model.predict(input="My name is Carlos and I am 42 years old.") assert predict_result == {"age": 42, "name": "Carlos"} def test_local_create_remote_use(client: WeaveClient): - model = LiteLLMCompletionModel( - model="gpt-4o", messages_template=[{"role": "user", "content": "{input}"}] - ) + model = LiteLLMCompletionModel(**model_args) publish_ref = weave.publish(model) remote_call_res = client.server.call_method( tsi.CallMethodReq.model_validate( @@ -87,10 +82,7 @@ def test_remote_create_local_use(client: WeaveClient): "obj": { "project_id": client._project_id(), "object_id": "LiteLLMCompletionModel", - "val": { - "model": "gpt-4o", - "messages_template": [{"role": "user", "content": "{input}"}], - }, + "val": model_args, "set_leaf_object_class": "LiteLLMCompletionModel", } } @@ -114,10 +106,7 @@ def test_remote_create_remote_use(client: WeaveClient): "obj": { "project_id": client._project_id(), "object_id": "LiteLLMCompletionModel", - "val": { - "model": "gpt-4o", - "messages_template": [{"role": "user", "content": "{input}"}], - }, + "val": model_args, "set_leaf_object_class": "LiteLLMCompletionModel", } }