From 083088951babf15374677864d6963c4b438b406c Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Tue, 10 Dec 2024 15:24:11 -0800 Subject: [PATCH 1/8] feat(weave): Add 'tracing_sample_rate' param to weave.op --- tests/trace/test_client_trace.py | 115 +++++++++++++++++++++++++++- weave/trace/context/call_context.py | 21 +++++ weave/trace/op.py | 83 ++++++++++++++------ 3 files changed, 191 insertions(+), 28 deletions(-) diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index 5b45abc8432..436761cb848 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -269,7 +269,7 @@ def adder(a: Number) -> Number: adder_v0 = adder - @weave.op() + @weave.op() # type: ignore def adder(a: Number, b) -> Number: return Number(value=a.value + b) @@ -2127,7 +2127,7 @@ def calculate(a: int, b: int) -> int: def test_call_query_stream_columns(client): @weave.op - def calculate(a: int, b: int) -> int: + def calculate(a: int, b: int) -> dict[str, Any]: return {"result": {"a + b": a + b}, "not result": 123} for i in range(2): @@ -2170,7 +2170,7 @@ def test_call_query_stream_columns_with_costs(client): return @weave.op - def calculate(a: int, b: int) -> int: + def calculate(a: int, b: int) -> dict[str, Any]: return { "result": {"a + b": a + b}, "not result": 123, @@ -2272,7 +2272,8 @@ def test_obj(val): # Ref at A, B and C test_op( - values[7], {"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})} + values[7], + {"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})}, ) for first, last, sort_by in [ @@ -2997,3 +2998,109 @@ def foo(): foo() assert len(list(weave_client.get_calls())) == 1 assert weave.trace.weave_init._current_inited_client is None + + +def test_op_sampling(client): + never_traced_calls = 0 + always_traced_calls = 0 + sometimes_traced_calls = 0 + + @weave.op(tracing_sample_rate=0.0) + def never_traced(x: int) -> int: + nonlocal never_traced_calls + never_traced_calls += 1 + return x + 1 
+ + @weave.op(tracing_sample_rate=1.0) + def always_traced(x: int) -> int: + nonlocal always_traced_calls + always_traced_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=0.5) + def sometimes_traced(x: int) -> int: + nonlocal sometimes_traced_calls + sometimes_traced_calls += 1 + return x + 1 + + # Never traced should execute but not be traced + for i in range(10): + never_traced(i) + assert never_traced_calls == 10 # Function was called + # NOTE: We can't assert here that never_traced.calls() is empty because that call requires + # the op to be published. If we never trace, we never publish the op. + assert "call_start" not in client.server.attribute_access_log + + # Always traced should execute and be traced + for i in range(10): + always_traced(i) + assert always_traced_calls == 10 # Function was called + assert len(list(always_traced.calls())) == 10 # And traced + # Sanity check that the call_start was logged, unlike in the never_traced case. + assert "call_start" in client.server.attribute_access_log + + # Sometimes traced should execute always but only be traced sometimes + num_runs = 100 + for i in range(num_runs): + sometimes_traced(i) + assert sometimes_traced_calls == num_runs # Function was called every time + num_traces = len(list(sometimes_traced.calls())) + assert 40 < num_traces < 60 # But only traced ~50% of the time + + +def test_op_sampling_inheritance(client): + parent_calls = 0 + child_calls = 0 + + @weave.op() + def child_op(x: int) -> int: + nonlocal child_calls + child_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=0.0) + def parent_op(x: int) -> int: + nonlocal parent_calls + parent_calls += 1 + return child_op(x) + + # When parent is sampled out, child should still execute but not be traced + for i in range(10): + parent_op(i) + + assert parent_calls == 10 # Parent function executed + assert child_calls == 10 # Child function executed + assert ( + "call_start" not in client.server.attribute_access_log + ) # But 
neither was traced + + # Reset counters + child_calls = 0 + + # Direct calls to child should execute and be traced + for i in range(10): + child_op(i) + + assert child_calls == 10 # Child function executed + assert len(list(child_op.calls())) == 10 # And was traced + assert "call_start" in client.server.attribute_access_log # Verify tracing occurred + + +def test_op_sampling_invalid_rates(client): + with pytest.raises(ValueError): + + @weave.op(tracing_sample_rate=-0.5) + def negative_rate(): + pass + + with pytest.raises(ValueError): + + @weave.op(tracing_sample_rate=1.5) + def too_high_rate(): + pass + + with pytest.raises(ValueError): + + @weave.op(tracing_sample_rate="invalid") # type: ignore + def invalid_type(): + pass diff --git a/weave/trace/context/call_context.py b/weave/trace/context/call_context.py index 402e1843ade..3a03bd167c3 100644 --- a/weave/trace/context/call_context.py +++ b/weave/trace/context/call_context.py @@ -20,6 +20,8 @@ class NoCurrentCallError(Exception): ... 
logger = logging.getLogger(__name__) +_tracing_enabled = contextvars.ContextVar("tracing_enabled", default=True) + def push_call(call: Call) -> None: new_stack = copy.copy(_call_stack.get()) @@ -136,3 +138,22 @@ def set_call_stack(stack: list[Call]) -> Iterator[list[Call]]: call_attributes: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar( "call_attributes", default={} ) + + +def get_tracing_enabled() -> bool: + return _tracing_enabled.get() + + +@contextlib.contextmanager +def set_tracing_enabled(enabled: bool) -> Iterator[None]: + token = _tracing_enabled.set(enabled) + try: + yield + finally: + _tracing_enabled.reset(token) + + +@contextlib.contextmanager +def tracing_disabled() -> Iterator[None]: + with set_tracing_enabled(False): + yield diff --git a/weave/trace/op.py b/weave/trace/op.py index 2b5835474d8..67230a4ca3d 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -4,6 +4,7 @@ import inspect import logging +import random import sys import traceback from collections.abc import Coroutine, Mapping @@ -26,7 +27,11 @@ from weave.trace.constants import TRACE_CALL_EMOJI from weave.trace.context import call_context from weave.trace.context import weave_client_context as weave_client_context -from weave.trace.context.call_context import call_attributes +from weave.trace.context.call_context import ( + call_attributes, + get_tracing_enabled, + tracing_disabled, +) from weave.trace.context.tests_context import get_raise_on_captured_errors from weave.trace.errors import OpCallError from weave.trace.refs import ObjectRef @@ -174,6 +179,8 @@ class Op(Protocol): # it disables child ops as well. 
_tracing_enabled: bool + tracing_sample_rate: float + def _set_on_input_handler(func: Op, on_input: OnInputHandlerType) -> None: if func._on_input_handler is not None: @@ -401,22 +408,33 @@ def _do_call( func = op.resolve_fn call = _placeholder_call() - pargs = None - if op._on_input_handler is not None: - pargs = op._on_input_handler(op, args, kwargs) - if not pargs: - pargs = _default_on_input_handler(op, args, kwargs) + pargs = ( + op._on_input_handler(op, args, kwargs) + if op._on_input_handler + else _default_on_input_handler(op, args, kwargs) + ) - if settings.should_disable_weave(): - res = func(*pargs.args, **pargs.kwargs) - elif weave_client_context.get_weave_client() is None: - res = func(*pargs.args, **pargs.kwargs) - elif not op._tracing_enabled: + skip_tracing = ( + settings.should_disable_weave() + or weave_client_context.get_weave_client() is None + or not op._tracing_enabled + or not get_tracing_enabled() + ) + + if skip_tracing: res = func(*pargs.args, **pargs.kwargs) else: + current_call = call_context.get_current_call() + if current_call is None: + # Root call: decide whether to trace based on sample rate + if random.random() > op.tracing_sample_rate: + # Disable tracing for this call and all descendants + with tracing_disabled(): + res = func(*pargs.args, **pargs.kwargs) + return res, call + + # Proceed with tracing try: - # This try/except allows us to fail gracefully and - # still let the user code continue to execute call = _create_call(op, *args, __weave=__weave, **kwargs) except OpCallError as e: raise e @@ -436,7 +454,6 @@ def _do_call( raise TypeError( "Internal error: Expected `_execute_call` to return a sync result" ) - execute_result = cast(tuple[Any, "Call"], execute_result) res, call = execute_result return res, call @@ -450,16 +467,30 @@ async def _do_call_async( ) -> tuple[Any, Call]: func = op.resolve_fn call = _placeholder_call() - if settings.should_disable_weave(): - res = await func(*args, **kwargs) - elif 
weave_client_context.get_weave_client() is None: - res = await func(*args, **kwargs) - elif not op._tracing_enabled: + + if ( + settings.should_disable_weave() + or weave_client_context.get_weave_client() is None + or not op._tracing_enabled + ): res = await func(*args, **kwargs) else: + current_call = call_context.get_current_call() + tracing_enabled = get_tracing_enabled() + if current_call is None: + # Root call: decide whether to trace based on sample rate + if random.random() > op.tracing_sample_rate: + # Disable tracing for this call and all descendants + with tracing_disabled(): + res = await func(*args, **kwargs) + return res, call + elif not tracing_enabled: + # Tracing is disabled in the context + res = await func(*args, **kwargs) + return res, call + + # Proceed with tracing try: - # This try/except allows us to fail gracefully and - # still let the user code continue to execute call = _create_call(op, *args, __weave=__weave, **kwargs) except OpCallError as e: raise e @@ -479,9 +510,6 @@ async def _do_call_async( raise TypeError( "Internal error: Expected `_execute_call` to return a coroutine" ) - execute_result = cast( - Coroutine[Any, Any, tuple[Any, "Call"]], execute_result - ) res, call = await execute_result return res, call @@ -540,6 +568,7 @@ def op( call_display_name: str | CallDisplayNameFunc | None = None, postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, + tracing_sample_rate: float = 1.0, ) -> Callable[[Callable], Op] | Op: """ A decorator to weave op-ify a function or method. Works for both sync and async. @@ -565,6 +594,7 @@ def op( postprocess_output (Optional[Callable[..., Any]]): A function to process the output after it's been returned from the function but before it's logged. This does not affect the actual output of the function, only the displayed output. + tracing_sample_rate (float): The sampling rate for tracing this function. Defaults to 1.0 (always trace). 
Returns: Union[Callable[[Any], Op], Op]: If called without arguments, returns a decorator. @@ -591,6 +621,10 @@ async def extract(): await extract() # calls the function and tracks the call in the Weave UI ``` """ + if not isinstance(tracing_sample_rate, (int, float)): + raise ValueError("tracing_sample_rate must be a float") + if not 0 <= tracing_sample_rate <= 1: + raise ValueError("tracing_sample_rate must be between 0 and 1") def op_deco(func: Callable) -> Op: # Check function type @@ -647,6 +681,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: wrapper._on_finish_handler = None # type: ignore wrapper._tracing_enabled = True # type: ignore + wrapper.tracing_sample_rate = tracing_sample_rate # type: ignore wrapper.get_captured_code = partial(get_captured_code, wrapper) # type: ignore From 05bc825659ee97938931e95dda78ee3047b90d08 Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Tue, 10 Dec 2024 15:48:22 -0800 Subject: [PATCH 2/8] reconcile async impl and add async tests --- tests/trace/test_client_trace.py | 89 ++++++++++++++++++++++++++++++-- weave/trace/op.py | 28 +++++----- 2 files changed, 102 insertions(+), 15 deletions(-) diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index 436761cb848..26d0618a855 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -3048,6 +3048,53 @@ def sometimes_traced(x: int) -> int: assert 40 < num_traces < 60 # But only traced ~50% of the time +def test_op_sampling_async(client): + never_traced_calls = 0 + always_traced_calls = 0 + sometimes_traced_calls = 0 + + @weave.op(tracing_sample_rate=0.0) + async def never_traced(x: int) -> int: + nonlocal never_traced_calls + never_traced_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=1.0) + async def always_traced(x: int) -> int: + nonlocal always_traced_calls + always_traced_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=0.5) + async def sometimes_traced(x: int) -> int: + nonlocal 
sometimes_traced_calls + sometimes_traced_calls += 1 + return x + 1 + + import asyncio + + # Never traced should execute but not be traced + for i in range(10): + asyncio.run(never_traced(i)) + assert never_traced_calls == 10 # Function was called + assert "call_start" not in client.server.attribute_access_log + + # Always traced should execute and be traced + for i in range(10): + asyncio.run(always_traced(i)) + assert always_traced_calls == 10 # Function was called + assert len(list(always_traced.calls())) == 10 # And traced + assert "call_start" in client.server.attribute_access_log + + # Sometimes traced should execute always but only be traced sometimes + num_runs = 100 + for i in range(num_runs): + asyncio.run(sometimes_traced(i)) + assert sometimes_traced_calls == num_runs # Function was called every time + num_traces = len(list(sometimes_traced.calls())) + assert 40 < num_traces < 60 # But only traced ~50% of the time + + def test_op_sampling_inheritance(client): parent_calls = 0 child_calls = 0 @@ -3070,9 +3117,7 @@ def parent_op(x: int) -> int: assert parent_calls == 10 # Parent function executed assert child_calls == 10 # Child function executed - assert ( - "call_start" not in client.server.attribute_access_log - ) # But neither was traced + assert "call_start" not in client.server.attribute_access_log # Reset counters child_calls = 0 @@ -3086,6 +3131,44 @@ def parent_op(x: int) -> int: assert "call_start" in client.server.attribute_access_log # Verify tracing occurred +def test_op_sampling_inheritance_async(client): + parent_calls = 0 + child_calls = 0 + + @weave.op() + async def child_op(x: int) -> int: + nonlocal child_calls + child_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=0.0) + async def parent_op(x: int) -> int: + nonlocal parent_calls + parent_calls += 1 + return await child_op(x) + + import asyncio + + # When parent is sampled out, child should still execute but not be traced + for i in range(10): + asyncio.run(parent_op(i)) + 
+ assert parent_calls == 10 # Parent function executed + assert child_calls == 10 # Child function executed + assert "call_start" not in client.server.attribute_access_log + + # Reset counters + child_calls = 0 + + # Direct calls to child should execute and be traced + for i in range(10): + asyncio.run(child_op(i)) + + assert child_calls == 10 # Child function executed + assert len(list(child_op.calls())) == 10 # And was traced + assert "call_start" in client.server.attribute_access_log # Verify tracing occurred + + def test_op_sampling_invalid_rates(client): with pytest.raises(ValueError): diff --git a/weave/trace/op.py b/weave/trace/op.py index 67230a4ca3d..ba26791068c 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -468,26 +468,30 @@ async def _do_call_async( func = op.resolve_fn call = _placeholder_call() - if ( + pargs = ( + op._on_input_handler(op, args, kwargs) + if op._on_input_handler + else _default_on_input_handler(op, args, kwargs) + ) + + skip_tracing = ( settings.should_disable_weave() or weave_client_context.get_weave_client() is None or not op._tracing_enabled - ): - res = await func(*args, **kwargs) + or not get_tracing_enabled() + ) + + if skip_tracing: + res = await func(*pargs.args, **pargs.kwargs) else: current_call = call_context.get_current_call() - tracing_enabled = get_tracing_enabled() if current_call is None: # Root call: decide whether to trace based on sample rate if random.random() > op.tracing_sample_rate: # Disable tracing for this call and all descendants with tracing_disabled(): - res = await func(*args, **kwargs) - return res, call - elif not tracing_enabled: - # Tracing is disabled in the context - res = await func(*args, **kwargs) - return res, call + res = await func(*pargs.args, **pargs.kwargs) + return res, call # Proceed with tracing try: @@ -501,10 +505,10 @@ async def _do_call_async( logger.error, ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), ) - res = await func(*args, **kwargs) + res = await 
func(*pargs.args, **pargs.kwargs) else: execute_result = _execute_op( - op, call, *args, __should_raise=__should_raise, **kwargs + op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs ) if not inspect.iscoroutine(execute_result): raise TypeError( From 8880ee28a453ffa124e939167b503d24f8631b7d Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Tue, 10 Dec 2024 16:19:51 -0800 Subject: [PATCH 3/8] fix linter errors --- weave/trace/op.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index ba26791068c..1db49ab2303 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -408,11 +408,11 @@ def _do_call( func = op.resolve_fn call = _placeholder_call() - pargs = ( - op._on_input_handler(op, args, kwargs) - if op._on_input_handler - else _default_on_input_handler(op, args, kwargs) - ) + pargs = None + if op._on_input_handler is not None: + pargs = op._on_input_handler(op, args, kwargs) + if not pargs: + pargs = _default_on_input_handler(op, args, kwargs) skip_tracing = ( settings.should_disable_weave() @@ -454,6 +454,7 @@ def _do_call( raise TypeError( "Internal error: Expected `_execute_call` to return a sync result" ) + execute_result = cast(tuple[Any, "Call"], execute_result) res, call = execute_result return res, call @@ -468,12 +469,6 @@ async def _do_call_async( func = op.resolve_fn call = _placeholder_call() - pargs = ( - op._on_input_handler(op, args, kwargs) - if op._on_input_handler - else _default_on_input_handler(op, args, kwargs) - ) - skip_tracing = ( settings.should_disable_weave() or weave_client_context.get_weave_client() is None @@ -482,7 +477,7 @@ async def _do_call_async( ) if skip_tracing: - res = await func(*pargs.args, **pargs.kwargs) + res = await func(*args, **kwargs) else: current_call = call_context.get_current_call() if current_call is None: @@ -490,7 +485,7 @@ async def _do_call_async( if random.random() > op.tracing_sample_rate: # Disable 
tracing for this call and all descendants with tracing_disabled(): - res = await func(*pargs.args, **pargs.kwargs) + res = await func(*args, **kwargs) return res, call # Proceed with tracing @@ -505,10 +500,10 @@ async def _do_call_async( logger.error, ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), ) - res = await func(*pargs.args, **pargs.kwargs) + res = await func(*args, **kwargs) else: execute_result = _execute_op( - op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs + op, call, *args, __should_raise=__should_raise, **kwargs ) if not inspect.iscoroutine(execute_result): raise TypeError( @@ -626,7 +621,7 @@ async def extract(): ``` """ if not isinstance(tracing_sample_rate, (int, float)): - raise ValueError("tracing_sample_rate must be a float") + raise TypeError("tracing_sample_rate must be a float") if not 0 <= tracing_sample_rate <= 1: raise ValueError("tracing_sample_rate must be between 0 and 1") From cf553a2f83db92e8e17c802ca79aaa1c37b12fa7 Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Tue, 10 Dec 2024 21:00:18 -0800 Subject: [PATCH 4/8] fix test --- tests/trace/test_client_trace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index 26d0618a855..b7a06640cba 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -3182,7 +3182,7 @@ def negative_rate(): def too_high_rate(): pass - with pytest.raises(ValueError): + with pytest.raises(TypeError): @weave.op(tracing_sample_rate="invalid") # type: ignore def invalid_type(): From 0cd0fc2cd971309a44f9dda358c29a80cf6cedf8 Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Wed, 11 Dec 2024 12:58:58 -0800 Subject: [PATCH 5/8] address review comments --- tests/trace/test_client_trace.py | 68 ++++++------- weave/trace/op.py | 160 ++++++++++++++++--------------- 2 files changed, 118 insertions(+), 110 deletions(-) diff --git a/tests/trace/test_client_trace.py 
b/tests/trace/test_client_trace.py index b7a06640cba..04808a2d1be 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -60,7 +60,7 @@ def get_client_project_id(client: weave_client.WeaveClient) -> str: def test_simple_op(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -229,7 +229,7 @@ def test_call_read_not_found(client): def test_graph_call_ordering(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -263,27 +263,27 @@ def simple_line_call_bootstrap(init_wandb: bool = False) -> OpCallSpec: class Number(weave.Object): value: int - @weave.op() + @weave.op def adder(a: Number) -> Number: return Number(value=a.value + a.value) adder_v0 = adder - @weave.op() # type: ignore + @weave.op # type: ignore def adder(a: Number, b) -> Number: return Number(value=a.value + b) - @weave.op() + @weave.op def subtractor(a: Number, b) -> Number: return Number(value=a.value - b) - @weave.op() + @weave.op def multiplier( a: Number, b ) -> int: # intentionally deviant in returning plain int - so that we have a different type return a.value * b - @weave.op() + @weave.op def liner(m: Number, b, x) -> Number: return adder(Number(value=multiplier(m, x)), b) @@ -691,7 +691,7 @@ def test_trace_call_query_offset(client): def test_trace_call_sort(client): - @weave.op() + @weave.op def basic_op(in_val: dict, delay) -> dict: import time @@ -727,7 +727,7 @@ def test_trace_call_sort_with_mixed_types(client): # SQLite does not support sorting over mixed types in a column, so we skip this test return - @weave.op() + @weave.op def basic_op(in_val: dict) -> dict: import time @@ -769,7 +769,7 @@ def basic_op(in_val: dict) -> dict: def test_trace_call_filter(client): is_sqlite = client_is_sqlite(client) - @weave.op() + @weave.op def basic_op(in_val: dict, delay) -> dict: return in_val @@ -1160,7 +1160,7 @@ def basic_op(in_val: dict, delay) -> dict: def test_ops_with_default_params(client): - @weave.op() + @weave.op 
def op_with_default(a: int, b: int = 10) -> int: return a + b @@ -1234,7 +1234,7 @@ class BaseTypeC(BaseTypeB): def test_attributes_on_ops(client): - @weave.op() + @weave.op def op_with_attrs(a: int, b: int) -> int: return a + b @@ -1277,7 +1277,7 @@ def test_dataclass_support(client): class MyDataclass: val: int - @weave.op() + @weave.op def dataclass_maker(a: MyDataclass, b: MyDataclass) -> MyDataclass: return MyDataclass(a.val + b.val) @@ -1322,7 +1322,7 @@ def dataclass_maker(a: MyDataclass, b: MyDataclass) -> MyDataclass: def test_op_retrieval(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -1336,7 +1336,7 @@ def test_bound_op_retrieval(client): class CustomType(weave.Object): a: int - @weave.op() + @weave.op def op_with_custom_type(self, v): return self.a + v @@ -1359,7 +1359,7 @@ def test_bound_op_retrieval_no_self(client): class CustomTypeWithoutSelf(weave.Object): a: int - @weave.op() + @weave.op def op_with_custom_type(me, v): return me.a + v @@ -1387,7 +1387,7 @@ def test_dataset_row_ref(client): def test_tuple_support(client): - @weave.op() + @weave.op def tuple_maker(a, b): return (a, b) @@ -1411,7 +1411,7 @@ def tuple_maker(a, b): def test_namedtuple_support(client): - @weave.op() + @weave.op def tuple_maker(a, b): return (a, b) @@ -1442,7 +1442,7 @@ def test_named_reuse(client): d_ref = weave.publish(d, "test_dataset") dataset = weave.ref(d_ref.uri()).get() - @weave.op() + @weave.op async def dummy_score(output): return 1 @@ -1489,7 +1489,7 @@ class MyUnknownClassB: def __init__(self, b_val) -> None: self.b_val = b_val - @weave.op() + @weave.op def op_with_unknown_types(a: MyUnknownClassA, b: float) -> MyUnknownClassB: return MyUnknownClassB(a.a_val + b) @@ -1564,19 +1564,19 @@ def init_weave_get_server_patched(api_key): def test_single_primitive_output(client): - @weave.op() + @weave.op def single_int_output(a: int) -> int: return a - @weave.op() + @weave.op def single_bool_output(a: int) -> bool: return a == 1 - 
@weave.op() + @weave.op def single_none_output(a: int) -> None: return None - @weave.op() + @weave.op def dict_output(a: int, b: bool, c: None) -> dict: return {"a": a, "b": b, "c": c} @@ -1669,14 +1669,14 @@ def test_mapped_execution(client, mapper): events = [] - @weave.op() + @weave.op def op_a(a: int) -> int: events.append("A(S):" + str(a)) time.sleep(0.3) events.append("A(E):" + str(a)) return a - @weave.op() + @weave.op def op_b(b: int) -> int: events.append("B(S):" + str(b)) time.sleep(0.2) @@ -1684,7 +1684,7 @@ def op_b(b: int) -> int: events.append("B(E):" + str(b)) return res - @weave.op() + @weave.op def op_c(c: int) -> int: events.append("C(S):" + str(c)) time.sleep(0.1) @@ -1692,7 +1692,7 @@ def op_c(c: int) -> int: events.append("C(E):" + str(c)) return res - @weave.op() + @weave.op def op_mapper(vals): return mapper(op_c, vals) @@ -2238,7 +2238,7 @@ def calculate(a: int, b: int) -> dict[str, Any]: @pytest.mark.skip("Not implemented: filter / sort through refs") def test_sort_and_filter_through_refs(client): - @weave.op() + @weave.op def test_op(label, val): return val @@ -2356,7 +2356,7 @@ def test_obj(val): def test_in_operation(client): - @weave.op() + @weave.op def test_op(label, val): return val @@ -2501,7 +2501,7 @@ def func(x): class BasicModel(weave.Model): - @weave.op() + @weave.op def predict(self, x): return {"answer": "42"} @@ -2547,7 +2547,7 @@ class SimpleObject(weave.Object): class NestedObject(weave.Object): b: SimpleObject - @weave.op() + @weave.op def return_nested_object(nested_obj: NestedObject): return nested_obj @@ -3099,7 +3099,7 @@ def test_op_sampling_inheritance(client): parent_calls = 0 child_calls = 0 - @weave.op() + @weave.op def child_op(x: int) -> int: nonlocal child_calls child_calls += 1 @@ -3135,7 +3135,7 @@ def test_op_sampling_inheritance_async(client): parent_calls = 0 child_calls = 0 - @weave.op() + @weave.op async def child_op(x: int) -> int: nonlocal child_calls child_calls += 1 diff --git a/weave/trace/op.py 
b/weave/trace/op.py index 1db49ab2303..37a6dc1c160 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -414,48 +414,52 @@ def _do_call( if not pargs: pargs = _default_on_input_handler(op, args, kwargs) - skip_tracing = ( - settings.should_disable_weave() - or weave_client_context.get_weave_client() is None - or not op._tracing_enabled - or not get_tracing_enabled() - ) - - if skip_tracing: + # Handle all of the possible cases where we would skip tracing. + if settings.should_disable_weave(): + res = func(*pargs.args, **pargs.kwargs) + return res, call + if weave_client_context.get_weave_client() is None: + res = func(*pargs.args, **pargs.kwargs) + return res, call + if not op._tracing_enabled: + res = func(*pargs.args, **pargs.kwargs) + return res, call + if not get_tracing_enabled(): + res = func(*pargs.args, **pargs.kwargs) + return res, call + + current_call = call_context.get_current_call() + if current_call is None: + # Root call: decide whether to trace based on sample rate + if random.random() > op.tracing_sample_rate: + # Disable tracing for this call and all descendants + with tracing_disabled(): + res = func(*pargs.args, **pargs.kwargs) + return res, call + + # Proceed with tracing + try: + call = _create_call(op, *args, __weave=__weave, **kwargs) + except OpCallError as e: + raise e + except Exception as e: + if get_raise_on_captured_errors(): + raise + log_once( + logger.error, + CALL_CREATE_MSG.format(traceback.format_exc()), + ) res = func(*pargs.args, **pargs.kwargs) else: - current_call = call_context.get_current_call() - if current_call is None: - # Root call: decide whether to trace based on sample rate - if random.random() > op.tracing_sample_rate: - # Disable tracing for this call and all descendants - with tracing_disabled(): - res = func(*pargs.args, **pargs.kwargs) - return res, call - - # Proceed with tracing - try: - call = _create_call(op, *args, __weave=__weave, **kwargs) - except OpCallError as e: - raise e - except Exception as e: 
- if get_raise_on_captured_errors(): - raise - log_once( - logger.error, - CALL_CREATE_MSG.format(traceback.format_exc()), - ) - res = func(*pargs.args, **pargs.kwargs) - else: - execute_result = _execute_op( - op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs + execute_result = _execute_op( + op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs + ) + if inspect.iscoroutine(execute_result): + raise TypeError( + "Internal error: Expected `_execute_call` to return a sync result" ) - if inspect.iscoroutine(execute_result): - raise TypeError( - "Internal error: Expected `_execute_call` to return a sync result" - ) - execute_result = cast(tuple[Any, "Call"], execute_result) - res, call = execute_result + execute_result = cast(tuple[Any, "Call"], execute_result) + res, call = execute_result return res, call @@ -469,47 +473,51 @@ async def _do_call_async( func = op.resolve_fn call = _placeholder_call() - skip_tracing = ( - settings.should_disable_weave() - or weave_client_context.get_weave_client() is None - or not op._tracing_enabled - or not get_tracing_enabled() - ) + # Handle all of the possible cases where we would skip tracing. 
+ if settings.should_disable_weave(): + res = await func(*args, **kwargs) + return res, call + if weave_client_context.get_weave_client() is None: + res = await func(*args, **kwargs) + return res, call + if not op._tracing_enabled: + res = await func(*args, **kwargs) + return res, call + if not get_tracing_enabled(): + res = await func(*args, **kwargs) + return res, call + + current_call = call_context.get_current_call() + if current_call is None: + # Root call: decide whether to trace based on sample rate + if random.random() > op.tracing_sample_rate: + # Disable tracing for this call and all descendants + with tracing_disabled(): + res = await func(*args, **kwargs) + return res, call - if skip_tracing: + # Proceed with tracing + try: + call = _create_call(op, *args, __weave=__weave, **kwargs) + except OpCallError as e: + raise e + except Exception as e: + if get_raise_on_captured_errors(): + raise + log_once( + logger.error, + ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), + ) res = await func(*args, **kwargs) else: - current_call = call_context.get_current_call() - if current_call is None: - # Root call: decide whether to trace based on sample rate - if random.random() > op.tracing_sample_rate: - # Disable tracing for this call and all descendants - with tracing_disabled(): - res = await func(*args, **kwargs) - return res, call - - # Proceed with tracing - try: - call = _create_call(op, *args, __weave=__weave, **kwargs) - except OpCallError as e: - raise e - except Exception as e: - if get_raise_on_captured_errors(): - raise - log_once( - logger.error, - ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), - ) - res = await func(*args, **kwargs) - else: - execute_result = _execute_op( - op, call, *args, __should_raise=__should_raise, **kwargs + execute_result = _execute_op( + op, call, *args, __should_raise=__should_raise, **kwargs + ) + if not inspect.iscoroutine(execute_result): + raise TypeError( + "Internal error: Expected `_execute_call` to return a 
coroutine" ) - if not inspect.iscoroutine(execute_result): - raise TypeError( - "Internal error: Expected `_execute_call` to return a coroutine" - ) - res, call = await execute_result + res, call = await execute_result return res, call From e5dc643e4190d378f862e27f6683d8904d90e37d Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Wed, 11 Dec 2024 17:33:27 -0800 Subject: [PATCH 6/8] add more padding on tests that measure sample rates --- tests/trace/test_client_trace.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index 04808a2d1be..9d1fb4bb223 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -3045,7 +3045,7 @@ def sometimes_traced(x: int) -> int: sometimes_traced(i) assert sometimes_traced_calls == num_runs # Function was called every time num_traces = len(list(sometimes_traced.calls())) - assert 40 < num_traces < 60 # But only traced ~50% of the time + assert 35 < num_traces < 65 # But only traced ~50% of the time def test_op_sampling_async(client): @@ -3092,7 +3092,7 @@ async def sometimes_traced(x: int) -> int: asyncio.run(sometimes_traced(i)) assert sometimes_traced_calls == num_runs # Function was called every time num_traces = len(list(sometimes_traced.calls())) - assert 40 < num_traces < 60 # But only traced ~50% of the time + assert 35 < num_traces < 65 # But only traced ~50% of the time def test_op_sampling_inheritance(client): From 70b2c265fda00424964130046a621d3ce624e9bc Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Wed, 11 Dec 2024 21:28:39 -0800 Subject: [PATCH 7/8] add docs for sampling rate --- docs/docs/guides/tracking/ops.md | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/docs/docs/guides/tracking/ops.md b/docs/docs/guides/tracking/ops.md index b69d5d1d91a..1fb259d1b37 100644 --- a/docs/docs/guides/tracking/ops.md +++ b/docs/docs/guides/tracking/ops.md @@ -116,6 +116,37 @@ A 
Weave op is a versioned function that automatically logs all calls. +## Control sampling rate + + + + You can control how frequently an op's calls are traced by setting the `tracing_sample_rate` parameter in the `@weave.op` decorator. This is useful for high-frequency ops where you only need to trace a subset of calls. + + ```python + @weave.op(tracing_sample_rate=0.1) # Only trace ~10% of calls + def high_frequency_op(x: int) -> int: + return x + 1 + + @weave.op(tracing_sample_rate=1.0) # Always trace (default) + def always_traced_op(x: int) -> int: + return x + 1 + ``` + + When an op's call is not sampled: + - The function executes normally + - No trace data is sent to Weave + - Child ops are also not traced for that call + + The sampling rate must be between 0.0 and 1.0 inclusive. + + + + ```plaintext + This feature is not available in TypeScript yet. Stay tuned! + ``` + + + ### Control call link output If you want to suppress the printing of call links during logging, you can use the `WEAVE_PRINT_CALL_LINK` environment variable to `false`. This can be useful if you want to reduce output verbosity and reduce clutter in your logs. From dfffb7848d9941394db14b764ee1d04fee681c25 Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Thu, 12 Dec 2024 21:40:16 -0800 Subject: [PATCH 8/8] Make tests simpler and update docs --- docs/docs/guides/tracking/ops.md | 2 ++ tests/trace/test_client_trace.py | 44 +++++++++++++++++++++++++++----- weave/trace/op.py | 4 ++- 3 files changed, 43 insertions(+), 7 deletions(-) diff --git a/docs/docs/guides/tracking/ops.md b/docs/docs/guides/tracking/ops.md index 1fb259d1b37..4c1e064b0aa 100644 --- a/docs/docs/guides/tracking/ops.md +++ b/docs/docs/guides/tracking/ops.md @@ -122,6 +122,8 @@ A Weave op is a versioned function that automatically logs all calls. You can control how frequently an op's calls are traced by setting the `tracing_sample_rate` parameter in the `@weave.op` decorator. 
This is useful for high-frequency ops where you only need to trace a subset of calls. + Note that sampling rates are only applied to root calls. If an op has a sample rate but is called from within another op, its own sample rate is ignored and it follows the parent call's sampling decision. + ```python @weave.op(tracing_sample_rate=0.1) # Only trace ~10% of calls def high_frequency_op(x: int) -> int: diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index 9d1fb4bb223..005c79f5cb0 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -3023,13 +3023,12 @@ def sometimes_traced(x: int) -> int: sometimes_traced_calls += 1 return x + 1 + weave.publish(never_traced) # Never traced should execute but not be traced for i in range(10): never_traced(i) assert never_traced_calls == 10 # Function was called - # NOTE: We can't assert here that never_traced.calls() is empty because that call requires - # the op to be published. If we never trace, we never publish the op. - assert "call_start" not in client.server.attribute_access_log + assert len(list(never_traced.calls())) == 0 # Not traced # Always traced should execute and be traced for i in range(10): @@ -3073,11 +3072,12 @@ async def sometimes_traced(x: int) -> int: import asyncio + weave.publish(never_traced) # Never traced should execute but not be traced for i in range(10): asyncio.run(never_traced(i)) assert never_traced_calls == 10 # Function was called - assert "call_start" not in client.server.attribute_access_log + assert len(list(never_traced.calls())) == 0 # Not traced # Always traced should execute and be traced for i in range(10): @@ -3111,13 +3111,14 @@ def parent_op(x: int) -> int: parent_calls += 1 return child_op(x) + weave.publish(parent_op) # When parent is sampled out, child should still execute but not be traced for i in range(10): parent_op(i) assert parent_calls == 10 # Parent function executed assert child_calls == 10 # Child function executed - assert "call_start" not in
client.server.attribute_access_log + assert len(list(parent_op.calls())) == 0 # Parent not traced # Reset counters child_calls = 0 @@ -3149,13 +3150,14 @@ async def parent_op(x: int) -> int: import asyncio + weave.publish(parent_op) # When parent is sampled out, child should still execute but not be traced for i in range(10): asyncio.run(parent_op(i)) assert parent_calls == 10 # Parent function executed assert child_calls == 10 # Child function executed - assert "call_start" not in client.server.attribute_access_log + assert len(list(parent_op.calls())) == 0 # Parent not traced # Reset counters child_calls = 0 @@ -3187,3 +3189,33 @@ def too_high_rate(): @weave.op(tracing_sample_rate="invalid") # type: ignore def invalid_type(): pass + + +def test_op_sampling_child_follows_parent(client): + parent_calls = 0 + child_calls = 0 + + @weave.op(tracing_sample_rate=0.0) # Never traced + def child_op(x: int) -> int: + nonlocal child_calls + child_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=1.0) # Always traced + def parent_op(x: int) -> int: + nonlocal parent_calls + parent_calls += 1 + return child_op(x) + + num_runs = 100 + for i in range(num_runs): + parent_op(i) + + assert parent_calls == num_runs # Parent was always executed + assert child_calls == num_runs # Child was always executed + + parent_traces = len(list(parent_op.calls())) + child_traces = len(list(child_op.calls())) + + assert parent_traces == num_runs # Parent was always traced + assert child_traces == num_runs # Child was traced whenever parent was diff --git a/weave/trace/op.py b/weave/trace/op.py index 37a6dc1c160..a89c7400d8b 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -437,7 +437,9 @@ def _do_call( res = func(*pargs.args, **pargs.kwargs) return res, call - # Proceed with tracing + # Proceed with tracing. Note that we don't check the sample rate here. + # Only root calls get sampling applied. + # If the parent was traced (sampled in), the child will be too. 
try: call = _create_call(op, *args, __weave=__weave, **kwargs) except OpCallError as e: