From 0cd0fc2cd971309a44f9dda358c29a80cf6cedf8 Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Wed, 11 Dec 2024 12:58:58 -0800 Subject: [PATCH] address review comments --- tests/trace/test_client_trace.py | 68 ++++++------- weave/trace/op.py | 160 ++++++++++++++++--------------- 2 files changed, 118 insertions(+), 110 deletions(-) diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index b7a06640cba..04808a2d1be 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -60,7 +60,7 @@ def get_client_project_id(client: weave_client.WeaveClient) -> str: def test_simple_op(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -229,7 +229,7 @@ def test_call_read_not_found(client): def test_graph_call_ordering(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -263,27 +263,27 @@ def simple_line_call_bootstrap(init_wandb: bool = False) -> OpCallSpec: class Number(weave.Object): value: int - @weave.op() + @weave.op def adder(a: Number) -> Number: return Number(value=a.value + a.value) adder_v0 = adder - @weave.op() # type: ignore + @weave.op # type: ignore def adder(a: Number, b) -> Number: return Number(value=a.value + b) - @weave.op() + @weave.op def subtractor(a: Number, b) -> Number: return Number(value=a.value - b) - @weave.op() + @weave.op def multiplier( a: Number, b ) -> int: # intentionally deviant in returning plain int - so that we have a different type return a.value * b - @weave.op() + @weave.op def liner(m: Number, b, x) -> Number: return adder(Number(value=multiplier(m, x)), b) @@ -691,7 +691,7 @@ def test_trace_call_query_offset(client): def test_trace_call_sort(client): - @weave.op() + @weave.op def basic_op(in_val: dict, delay) -> dict: import time @@ -727,7 +727,7 @@ def test_trace_call_sort_with_mixed_types(client): # SQLite does not support sorting over mixed types in a column, so we skip this test return - @weave.op() + @weave.op def basic_op(in_val: dict) -> dict: import time @@ -769,7 +769,7 @@ def basic_op(in_val: dict) -> dict: def test_trace_call_filter(client): is_sqlite = client_is_sqlite(client) - @weave.op() + @weave.op def basic_op(in_val: dict, delay) -> dict: return in_val @@ -1160,7 +1160,7 @@ def basic_op(in_val: dict, delay) -> dict: def test_ops_with_default_params(client): - @weave.op() + @weave.op def op_with_default(a: int, b: int = 10) -> int: return a + b @@ -1234,7 +1234,7 @@ class BaseTypeC(BaseTypeB): def test_attributes_on_ops(client): - @weave.op() + @weave.op def op_with_attrs(a: int, b: int) -> int: return a + b @@ -1277,7 +1277,7 @@ def test_dataclass_support(client): class MyDataclass: val: int - @weave.op() + @weave.op def dataclass_maker(a: MyDataclass, b: MyDataclass) -> MyDataclass: return MyDataclass(a.val + b.val) @@ -1322,7 +1322,7 @@ def dataclass_maker(a: MyDataclass, b: MyDataclass) -> MyDataclass: def test_op_retrieval(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -1336,7 +1336,7 @@ def test_bound_op_retrieval(client): class CustomType(weave.Object): a: int - @weave.op() + @weave.op def op_with_custom_type(self, v): return self.a + v @@ -1359,7 +1359,7 @@ def test_bound_op_retrieval_no_self(client): class CustomTypeWithoutSelf(weave.Object): a: int - @weave.op() + @weave.op def op_with_custom_type(me, v): return me.a + v @@ -1387,7 +1387,7 @@ def test_dataset_row_ref(client): def test_tuple_support(client): - @weave.op() + @weave.op def tuple_maker(a, b): return (a, b) @@ -1411,7 +1411,7 @@ def tuple_maker(a, b): def test_namedtuple_support(client): - @weave.op() + @weave.op def tuple_maker(a, b): return (a, b) @@ -1442,7 +1442,7 @@ def test_named_reuse(client): d_ref = weave.publish(d, "test_dataset") dataset = weave.ref(d_ref.uri()).get() - @weave.op() + @weave.op async def dummy_score(output): return 1 @@ -1489,7 +1489,7 @@ class MyUnknownClassB: def __init__(self, b_val) -> None: self.b_val = b_val - @weave.op() + @weave.op def op_with_unknown_types(a: MyUnknownClassA, b: float) -> MyUnknownClassB: return MyUnknownClassB(a.a_val + b) @@ -1564,19 +1564,19 @@ def init_weave_get_server_patched(api_key): def test_single_primitive_output(client): - @weave.op() + @weave.op def single_int_output(a: int) -> int: return a - @weave.op() + @weave.op def single_bool_output(a: int) -> bool: return a == 1 - @weave.op() + @weave.op def single_none_output(a: int) -> None: return None - @weave.op() + @weave.op def dict_output(a: int, b: bool, c: None) -> dict: return {"a": a, "b": b, "c": c} @@ -1669,14 +1669,14 @@ def test_mapped_execution(client, mapper): events = [] - @weave.op() + @weave.op def op_a(a: int) -> int: events.append("A(S):" + str(a)) time.sleep(0.3) events.append("A(E):" + str(a)) return a - @weave.op() + @weave.op def op_b(b: int) -> int: events.append("B(S):" + str(b)) time.sleep(0.2) @@ -1684,7 +1684,7 @@ def op_b(b: int) -> int: events.append("B(E):" + str(b)) return res - @weave.op() + @weave.op def op_c(c: int) -> int: events.append("C(S):" + str(c)) time.sleep(0.1) @@ -1692,7 +1692,7 @@ def op_c(c: int) -> int: events.append("C(E):" + str(c)) return res - @weave.op() + @weave.op def op_mapper(vals): return mapper(op_c, vals) @@ -2238,7 +2238,7 @@ def calculate(a: int, b: int) -> dict[str, Any]: @pytest.mark.skip("Not implemented: filter / sort through refs") def test_sort_and_filter_through_refs(client): - @weave.op() + @weave.op def test_op(label, val): return val @@ -2356,7 +2356,7 @@ def test_obj(val): def test_in_operation(client): - @weave.op() + @weave.op def test_op(label, val): return val @@ -2501,7 +2501,7 @@ def func(x): class BasicModel(weave.Model): - @weave.op() + @weave.op def predict(self, x): return {"answer": "42"} @@ -2547,7 +2547,7 @@ class SimpleObject(weave.Object): class NestedObject(weave.Object): b: SimpleObject - @weave.op() + @weave.op def return_nested_object(nested_obj: NestedObject): return nested_obj @@ -3099,7 +3099,7 @@ def test_op_sampling_inheritance(client): parent_calls = 0 child_calls = 0 - @weave.op() + @weave.op def child_op(x: int) -> int: nonlocal child_calls child_calls += 1 @@ -3135,7 +3135,7 @@ def test_op_sampling_inheritance_async(client): parent_calls = 0 child_calls = 0 - @weave.op() + @weave.op async def child_op(x: int) -> int: nonlocal child_calls child_calls += 1 diff --git a/weave/trace/op.py b/weave/trace/op.py index 1db49ab2303..37a6dc1c160 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -414,48 +414,52 @@ def _do_call( if not pargs: pargs = _default_on_input_handler(op, args, kwargs) - skip_tracing = ( - settings.should_disable_weave() - or weave_client_context.get_weave_client() is None - or not op._tracing_enabled - or not get_tracing_enabled() - ) - - if skip_tracing: + # Handle all of the possible cases where we would skip tracing. + if settings.should_disable_weave(): + res = func(*pargs.args, **pargs.kwargs) + return res, call + if weave_client_context.get_weave_client() is None: + res = func(*pargs.args, **pargs.kwargs) + return res, call + if not op._tracing_enabled: + res = func(*pargs.args, **pargs.kwargs) + return res, call + if not get_tracing_enabled(): + res = func(*pargs.args, **pargs.kwargs) + return res, call + + current_call = call_context.get_current_call() + if current_call is None: + # Root call: decide whether to trace based on sample rate + if random.random() > op.tracing_sample_rate: + # Disable tracing for this call and all descendants + with tracing_disabled(): + res = func(*pargs.args, **pargs.kwargs) + return res, call + + # Proceed with tracing + try: + call = _create_call(op, *args, __weave=__weave, **kwargs) + except OpCallError as e: + raise e + except Exception as e: + if get_raise_on_captured_errors(): + raise + log_once( + logger.error, + CALL_CREATE_MSG.format(traceback.format_exc()), + ) res = func(*pargs.args, **pargs.kwargs) else: - current_call = call_context.get_current_call() - if current_call is None: - # Root call: decide whether to trace based on sample rate - if random.random() > op.tracing_sample_rate: - # Disable tracing for this call and all descendants - with tracing_disabled(): - res = func(*pargs.args, **pargs.kwargs) - return res, call - - # Proceed with tracing - try: - call = _create_call(op, *args, __weave=__weave, **kwargs) - except OpCallError as e: - raise e - except Exception as e: - if get_raise_on_captured_errors(): - raise - log_once( - logger.error, - CALL_CREATE_MSG.format(traceback.format_exc()), - ) - res = func(*pargs.args, **pargs.kwargs) - else: - execute_result = _execute_op( - op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs + execute_result = _execute_op( + op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs + ) + if inspect.iscoroutine(execute_result): + raise TypeError( + "Internal error: Expected `_execute_call` to return a sync result" ) - if inspect.iscoroutine(execute_result): - raise TypeError( - "Internal error: Expected `_execute_call` to return a sync result" - ) - execute_result = cast(tuple[Any, "Call"], execute_result) - res, call = execute_result + execute_result = cast(tuple[Any, "Call"], execute_result) + res, call = execute_result return res, call @@ -469,47 +473,51 @@ async def _do_call_async( func = op.resolve_fn call = _placeholder_call() - skip_tracing = ( - settings.should_disable_weave() - or weave_client_context.get_weave_client() is None - or not op._tracing_enabled - or not get_tracing_enabled() - ) + # Handle all of the possible cases where we would skip tracing. + if settings.should_disable_weave(): + res = await func(*args, **kwargs) + return res, call + if weave_client_context.get_weave_client() is None: + res = await func(*args, **kwargs) + return res, call + if not op._tracing_enabled: + res = await func(*args, **kwargs) + return res, call + if not get_tracing_enabled(): + res = await func(*args, **kwargs) + return res, call + + current_call = call_context.get_current_call() + if current_call is None: + # Root call: decide whether to trace based on sample rate + if random.random() > op.tracing_sample_rate: + # Disable tracing for this call and all descendants + with tracing_disabled(): + res = await func(*args, **kwargs) + return res, call - if skip_tracing: + # Proceed with tracing + try: + call = _create_call(op, *args, __weave=__weave, **kwargs) + except OpCallError as e: + raise e + except Exception as e: + if get_raise_on_captured_errors(): + raise + log_once( + logger.error, + ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), + ) res = await func(*args, **kwargs) else: - current_call = call_context.get_current_call() - if current_call is None: - # Root call: decide whether to trace based on sample rate - if random.random() > op.tracing_sample_rate: - # Disable tracing for this call and all descendants - with tracing_disabled(): - res = await func(*args, **kwargs) - return res, call - - # Proceed with tracing - try: - call = _create_call(op, *args, __weave=__weave, **kwargs) - except OpCallError as e: - raise e - except Exception as e: - if get_raise_on_captured_errors(): - raise - log_once( - logger.error, - ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), - ) - res = await func(*args, **kwargs) - else: - execute_result = _execute_op( - op, call, *args, __should_raise=__should_raise, **kwargs + execute_result = _execute_op( + op, call, *args, __should_raise=__should_raise, **kwargs + ) + if not inspect.iscoroutine(execute_result): + raise TypeError( + "Internal error: Expected `_execute_call` to return a coroutine" ) - if not inspect.iscoroutine(execute_result): - raise TypeError( - "Internal error: Expected `_execute_call` to return a coroutine" - ) - res, call = await execute_result + res, call = await execute_result return res, call