diff --git a/docs/docs/guides/tracking/ops.md b/docs/docs/guides/tracking/ops.md index b69d5d1d91a..4c1e064b0aa 100644 --- a/docs/docs/guides/tracking/ops.md +++ b/docs/docs/guides/tracking/ops.md @@ -116,6 +116,39 @@ A Weave op is a versioned function that automatically logs all calls. +## Control sampling rate + + + + You can control how frequently an op's calls are traced by setting the `tracing_sample_rate` parameter in the `@weave.op` decorator. This is useful for high-frequency ops where you only need to trace a subset of calls. + + Note that sampling rates are only applied to root calls. If an op with a sample rate is called from inside another op, its own sample rate is ignored: it is traced if and only if the root call is traced. + + ```python + @weave.op(tracing_sample_rate=0.1) # Only trace ~10% of calls + def high_frequency_op(x: int) -> int: + return x + 1 + + @weave.op(tracing_sample_rate=1.0) # Always trace (default) + def always_traced_op(x: int) -> int: + return x + 1 + ``` + + When an op's call is not sampled: + - The function executes normally + - No trace data is sent to Weave + - Child ops are also not traced for that call + + The sampling rate must be between 0.0 and 1.0 inclusive. + + + + ```plaintext + This feature is not available in TypeScript yet. Stay tuned! + ``` + + + ### Control call link output If you want to suppress the printing of call links during logging, you can use the `WEAVE_PRINT_CALL_LINK` environment variable to `false`. This can be useful if you want to reduce output verbosity and reduce clutter in your logs. 
diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index 5b45abc8432..005c79f5cb0 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -60,7 +60,7 @@ def get_client_project_id(client: weave_client.WeaveClient) -> str: def test_simple_op(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -229,7 +229,7 @@ def test_call_read_not_found(client): def test_graph_call_ordering(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -263,27 +263,27 @@ def simple_line_call_bootstrap(init_wandb: bool = False) -> OpCallSpec: class Number(weave.Object): value: int - @weave.op() + @weave.op def adder(a: Number) -> Number: return Number(value=a.value + a.value) adder_v0 = adder - @weave.op() + @weave.op # type: ignore def adder(a: Number, b) -> Number: return Number(value=a.value + b) - @weave.op() + @weave.op def subtractor(a: Number, b) -> Number: return Number(value=a.value - b) - @weave.op() + @weave.op def multiplier( a: Number, b ) -> int: # intentionally deviant in returning plain int - so that we have a different type return a.value * b - @weave.op() + @weave.op def liner(m: Number, b, x) -> Number: return adder(Number(value=multiplier(m, x)), b) @@ -691,7 +691,7 @@ def test_trace_call_query_offset(client): def test_trace_call_sort(client): - @weave.op() + @weave.op def basic_op(in_val: dict, delay) -> dict: import time @@ -727,7 +727,7 @@ def test_trace_call_sort_with_mixed_types(client): # SQLite does not support sorting over mixed types in a column, so we skip this test return - @weave.op() + @weave.op def basic_op(in_val: dict) -> dict: import time @@ -769,7 +769,7 @@ def basic_op(in_val: dict) -> dict: def test_trace_call_filter(client): is_sqlite = client_is_sqlite(client) - @weave.op() + @weave.op def basic_op(in_val: dict, delay) -> dict: return in_val @@ -1160,7 +1160,7 @@ def basic_op(in_val: dict, delay) -> dict: def 
test_ops_with_default_params(client): - @weave.op() + @weave.op def op_with_default(a: int, b: int = 10) -> int: return a + b @@ -1234,7 +1234,7 @@ class BaseTypeC(BaseTypeB): def test_attributes_on_ops(client): - @weave.op() + @weave.op def op_with_attrs(a: int, b: int) -> int: return a + b @@ -1277,7 +1277,7 @@ def test_dataclass_support(client): class MyDataclass: val: int - @weave.op() + @weave.op def dataclass_maker(a: MyDataclass, b: MyDataclass) -> MyDataclass: return MyDataclass(a.val + b.val) @@ -1322,7 +1322,7 @@ def dataclass_maker(a: MyDataclass, b: MyDataclass) -> MyDataclass: def test_op_retrieval(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -1336,7 +1336,7 @@ def test_bound_op_retrieval(client): class CustomType(weave.Object): a: int - @weave.op() + @weave.op def op_with_custom_type(self, v): return self.a + v @@ -1359,7 +1359,7 @@ def test_bound_op_retrieval_no_self(client): class CustomTypeWithoutSelf(weave.Object): a: int - @weave.op() + @weave.op def op_with_custom_type(me, v): return me.a + v @@ -1387,7 +1387,7 @@ def test_dataset_row_ref(client): def test_tuple_support(client): - @weave.op() + @weave.op def tuple_maker(a, b): return (a, b) @@ -1411,7 +1411,7 @@ def tuple_maker(a, b): def test_namedtuple_support(client): - @weave.op() + @weave.op def tuple_maker(a, b): return (a, b) @@ -1442,7 +1442,7 @@ def test_named_reuse(client): d_ref = weave.publish(d, "test_dataset") dataset = weave.ref(d_ref.uri()).get() - @weave.op() + @weave.op async def dummy_score(output): return 1 @@ -1489,7 +1489,7 @@ class MyUnknownClassB: def __init__(self, b_val) -> None: self.b_val = b_val - @weave.op() + @weave.op def op_with_unknown_types(a: MyUnknownClassA, b: float) -> MyUnknownClassB: return MyUnknownClassB(a.a_val + b) @@ -1564,19 +1564,19 @@ def init_weave_get_server_patched(api_key): def test_single_primitive_output(client): - @weave.op() + @weave.op def single_int_output(a: int) -> int: return a - @weave.op() + @weave.op 
def single_bool_output(a: int) -> bool: return a == 1 - @weave.op() + @weave.op def single_none_output(a: int) -> None: return None - @weave.op() + @weave.op def dict_output(a: int, b: bool, c: None) -> dict: return {"a": a, "b": b, "c": c} @@ -1669,14 +1669,14 @@ def test_mapped_execution(client, mapper): events = [] - @weave.op() + @weave.op def op_a(a: int) -> int: events.append("A(S):" + str(a)) time.sleep(0.3) events.append("A(E):" + str(a)) return a - @weave.op() + @weave.op def op_b(b: int) -> int: events.append("B(S):" + str(b)) time.sleep(0.2) @@ -1684,7 +1684,7 @@ def op_b(b: int) -> int: events.append("B(E):" + str(b)) return res - @weave.op() + @weave.op def op_c(c: int) -> int: events.append("C(S):" + str(c)) time.sleep(0.1) @@ -1692,7 +1692,7 @@ def op_c(c: int) -> int: events.append("C(E):" + str(c)) return res - @weave.op() + @weave.op def op_mapper(vals): return mapper(op_c, vals) @@ -2127,7 +2127,7 @@ def calculate(a: int, b: int) -> int: def test_call_query_stream_columns(client): @weave.op - def calculate(a: int, b: int) -> int: + def calculate(a: int, b: int) -> dict[str, Any]: return {"result": {"a + b": a + b}, "not result": 123} for i in range(2): @@ -2170,7 +2170,7 @@ def test_call_query_stream_columns_with_costs(client): return @weave.op - def calculate(a: int, b: int) -> int: + def calculate(a: int, b: int) -> dict[str, Any]: return { "result": {"a + b": a + b}, "not result": 123, @@ -2238,7 +2238,7 @@ def calculate(a: int, b: int) -> int: @pytest.mark.skip("Not implemented: filter / sort through refs") def test_sort_and_filter_through_refs(client): - @weave.op() + @weave.op def test_op(label, val): return val @@ -2272,7 +2272,8 @@ def test_obj(val): # Ref at A, B and C test_op( - values[7], {"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})} + values[7], + {"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})}, ) for first, last, sort_by in [ @@ -2355,7 +2356,7 @@ def test_obj(val): def 
test_in_operation(client): - @weave.op() + @weave.op def test_op(label, val): return val @@ -2500,7 +2501,7 @@ def func(x): class BasicModel(weave.Model): - @weave.op() + @weave.op def predict(self, x): return {"answer": "42"} @@ -2546,7 +2547,7 @@ class SimpleObject(weave.Object): class NestedObject(weave.Object): b: SimpleObject - @weave.op() + @weave.op def return_nested_object(nested_obj: NestedObject): return nested_obj @@ -2997,3 +2998,224 @@ def foo(): foo() assert len(list(weave_client.get_calls())) == 1 assert weave.trace.weave_init._current_inited_client is None + + +def test_op_sampling(client): + never_traced_calls = 0 + always_traced_calls = 0 + sometimes_traced_calls = 0 + + @weave.op(tracing_sample_rate=0.0) + def never_traced(x: int) -> int: + nonlocal never_traced_calls + never_traced_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=1.0) + def always_traced(x: int) -> int: + nonlocal always_traced_calls + always_traced_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=0.5) + def sometimes_traced(x: int) -> int: + nonlocal sometimes_traced_calls + sometimes_traced_calls += 1 + return x + 1 + + weave.publish(never_traced) + # Never traced should execute but not be traced + for i in range(10): + never_traced(i) + assert never_traced_calls == 10 # Function was called + assert len(list(never_traced.calls())) == 0 # Not traced + + # Always traced should execute and be traced + for i in range(10): + always_traced(i) + assert always_traced_calls == 10 # Function was called + assert len(list(always_traced.calls())) == 10 # And traced + # Sanity check that the call_start was logged, unlike in the never_traced case. 
+ assert "call_start" in client.server.attribute_access_log + + # Sometimes traced should execute always but only be traced sometimes + num_runs = 100 + for i in range(num_runs): + sometimes_traced(i) + assert sometimes_traced_calls == num_runs # Function was called every time + num_traces = len(list(sometimes_traced.calls())) + assert 35 < num_traces < 65 # But only traced ~50% of the time + + +def test_op_sampling_async(client): + never_traced_calls = 0 + always_traced_calls = 0 + sometimes_traced_calls = 0 + + @weave.op(tracing_sample_rate=0.0) + async def never_traced(x: int) -> int: + nonlocal never_traced_calls + never_traced_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=1.0) + async def always_traced(x: int) -> int: + nonlocal always_traced_calls + always_traced_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=0.5) + async def sometimes_traced(x: int) -> int: + nonlocal sometimes_traced_calls + sometimes_traced_calls += 1 + return x + 1 + + import asyncio + + weave.publish(never_traced) + # Never traced should execute but not be traced + for i in range(10): + asyncio.run(never_traced(i)) + assert never_traced_calls == 10 # Function was called + assert len(list(never_traced.calls())) == 0 # Not traced + + # Always traced should execute and be traced + for i in range(10): + asyncio.run(always_traced(i)) + assert always_traced_calls == 10 # Function was called + assert len(list(always_traced.calls())) == 10 # And traced + assert "call_start" in client.server.attribute_access_log + + # Sometimes traced should execute always but only be traced sometimes + num_runs = 100 + for i in range(num_runs): + asyncio.run(sometimes_traced(i)) + assert sometimes_traced_calls == num_runs # Function was called every time + num_traces = len(list(sometimes_traced.calls())) + assert 35 < num_traces < 65 # But only traced ~50% of the time + + +def test_op_sampling_inheritance(client): + parent_calls = 0 + child_calls = 0 + + @weave.op + def child_op(x: 
int) -> int: + nonlocal child_calls + child_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=0.0) + def parent_op(x: int) -> int: + nonlocal parent_calls + parent_calls += 1 + return child_op(x) + + weave.publish(parent_op) + # When parent is sampled out, child should still execute but not be traced + for i in range(10): + parent_op(i) + + assert parent_calls == 10 # Parent function executed + assert child_calls == 10 # Child function executed + assert len(list(parent_op.calls())) == 0 # Parent not traced + + # Reset counters + child_calls = 0 + + # Direct calls to child should execute and be traced + for i in range(10): + child_op(i) + + assert child_calls == 10 # Child function executed + assert len(list(child_op.calls())) == 10 # And was traced + assert "call_start" in client.server.attribute_access_log # Verify tracing occurred + + +def test_op_sampling_inheritance_async(client): + parent_calls = 0 + child_calls = 0 + + @weave.op + async def child_op(x: int) -> int: + nonlocal child_calls + child_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=0.0) + async def parent_op(x: int) -> int: + nonlocal parent_calls + parent_calls += 1 + return await child_op(x) + + import asyncio + + weave.publish(parent_op) + # When parent is sampled out, child should still execute but not be traced + for i in range(10): + asyncio.run(parent_op(i)) + + assert parent_calls == 10 # Parent function executed + assert child_calls == 10 # Child function executed + assert len(list(parent_op.calls())) == 0 # Parent not traced + + # Reset counters + child_calls = 0 + + # Direct calls to child should execute and be traced + for i in range(10): + asyncio.run(child_op(i)) + + assert child_calls == 10 # Child function executed + assert len(list(child_op.calls())) == 10 # And was traced + assert "call_start" in client.server.attribute_access_log # Verify tracing occurred + + +def test_op_sampling_invalid_rates(client): + with pytest.raises(ValueError): + + 
@weave.op(tracing_sample_rate=-0.5) + def negative_rate(): + pass + + with pytest.raises(ValueError): + + @weave.op(tracing_sample_rate=1.5) + def too_high_rate(): + pass + + with pytest.raises(TypeError): + + @weave.op(tracing_sample_rate="invalid") # type: ignore + def invalid_type(): + pass + + +def test_op_sampling_child_follows_parent(client): + parent_calls = 0 + child_calls = 0 + + @weave.op(tracing_sample_rate=0.0) # Never traced + def child_op(x: int) -> int: + nonlocal child_calls + child_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=1.0) # Always traced + def parent_op(x: int) -> int: + nonlocal parent_calls + parent_calls += 1 + return child_op(x) + + num_runs = 100 + for i in range(num_runs): + parent_op(i) + + assert parent_calls == num_runs # Parent was always executed + assert child_calls == num_runs # Child was always executed + + parent_traces = len(list(parent_op.calls())) + child_traces = len(list(child_op.calls())) + + assert parent_traces == num_runs # Parent was always traced + assert child_traces == num_runs # Child was traced whenever parent was diff --git a/weave/trace/context/call_context.py b/weave/trace/context/call_context.py index 402e1843ade..3a03bd167c3 100644 --- a/weave/trace/context/call_context.py +++ b/weave/trace/context/call_context.py @@ -20,6 +20,8 @@ class NoCurrentCallError(Exception): ... 
logger = logging.getLogger(__name__) +_tracing_enabled = contextvars.ContextVar("tracing_enabled", default=True) + def push_call(call: Call) -> None: new_stack = copy.copy(_call_stack.get()) @@ -136,3 +138,22 @@ def set_call_stack(stack: list[Call]) -> Iterator[list[Call]]: call_attributes: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar( "call_attributes", default={} ) + + +def get_tracing_enabled() -> bool: + return _tracing_enabled.get() + + +@contextlib.contextmanager +def set_tracing_enabled(enabled: bool) -> Iterator[None]: + token = _tracing_enabled.set(enabled) + try: + yield + finally: + _tracing_enabled.reset(token) + + +@contextlib.contextmanager +def tracing_disabled() -> Iterator[None]: + with set_tracing_enabled(False): + yield diff --git a/weave/trace/op.py b/weave/trace/op.py index 2b5835474d8..a89c7400d8b 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -4,6 +4,7 @@ import inspect import logging +import random import sys import traceback from collections.abc import Coroutine, Mapping @@ -26,7 +27,11 @@ from weave.trace.constants import TRACE_CALL_EMOJI from weave.trace.context import call_context from weave.trace.context import weave_client_context as weave_client_context -from weave.trace.context.call_context import call_attributes +from weave.trace.context.call_context import ( + call_attributes, + get_tracing_enabled, + tracing_disabled, +) from weave.trace.context.tests_context import get_raise_on_captured_errors from weave.trace.errors import OpCallError from weave.trace.refs import ObjectRef @@ -174,6 +179,8 @@ class Op(Protocol): # it disables child ops as well. _tracing_enabled: bool + tracing_sample_rate: float + def _set_on_input_handler(func: Op, on_input: OnInputHandlerType) -> None: if func._on_input_handler is not None: @@ -407,37 +414,54 @@ def _do_call( if not pargs: pargs = _default_on_input_handler(op, args, kwargs) + # Handle all of the possible cases where we would skip tracing. 
if settings.should_disable_weave(): res = func(*pargs.args, **pargs.kwargs) - elif weave_client_context.get_weave_client() is None: + return res, call + if weave_client_context.get_weave_client() is None: + res = func(*pargs.args, **pargs.kwargs) + return res, call + if not op._tracing_enabled: + res = func(*pargs.args, **pargs.kwargs) + return res, call + if not get_tracing_enabled(): res = func(*pargs.args, **pargs.kwargs) - elif not op._tracing_enabled: + return res, call + + current_call = call_context.get_current_call() + if current_call is None: + # Root call: trace with probability `tracing_sample_rate`. random.random() is in [0.0, 1.0), so `>=` guarantees a rate of 0.0 never traces and 1.0 always traces. + if random.random() >= op.tracing_sample_rate: + # Disable tracing for this call and all descendants + with tracing_disabled(): + res = func(*pargs.args, **pargs.kwargs) + return res, call + + # Proceed with tracing. Note that we don't check the sample rate here. + # Only root calls get sampling applied. + # If the parent was traced (sampled in), the child will be too. + try: + call = _create_call(op, *args, __weave=__weave, **kwargs) + except OpCallError as e: + raise e + except Exception as e: + if get_raise_on_captured_errors(): + raise + log_once( + logger.error, + CALL_CREATE_MSG.format(traceback.format_exc()), + ) res = func(*pargs.args, **pargs.kwargs) else: - try: - # This try/except allows us to fail gracefully and - # still let the user code continue to execute - call = _create_call(op, *args, __weave=__weave, **kwargs) - except OpCallError as e: - raise e - except Exception as e: - if get_raise_on_captured_errors(): - raise - log_once( - logger.error, - CALL_CREATE_MSG.format(traceback.format_exc()), - ) - res = func(*pargs.args, **pargs.kwargs) - else: - execute_result = _execute_op( - op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs + execute_result = _execute_op( + op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs + ) + if inspect.iscoroutine(execute_result): + raise TypeError( + "Internal error: Expected 
`_execute_call` to return a sync result" ) - if inspect.iscoroutine(execute_result): - raise TypeError( - "Internal error: Expected `_execute_call` to return a sync result" - ) - execute_result = cast(tuple[Any, "Call"], execute_result) - res, call = execute_result + execute_result = cast(tuple[Any, "Call"], execute_result) + res, call = execute_result return res, call @@ -450,39 +474,52 @@ async def _do_call_async( ) -> tuple[Any, Call]: func = op.resolve_fn call = _placeholder_call() + + # Handle all of the possible cases where we would skip tracing. if settings.should_disable_weave(): res = await func(*args, **kwargs) - elif weave_client_context.get_weave_client() is None: + return res, call + if weave_client_context.get_weave_client() is None: res = await func(*args, **kwargs) - elif not op._tracing_enabled: + return res, call + if not op._tracing_enabled: + res = await func(*args, **kwargs) + return res, call + if not get_tracing_enabled(): + res = await func(*args, **kwargs) + return res, call + + current_call = call_context.get_current_call() + if current_call is None: + # Root call: trace with probability `tracing_sample_rate`. random.random() is in [0.0, 1.0), so `>=` guarantees a rate of 0.0 never traces and 1.0 always traces. + if random.random() >= op.tracing_sample_rate: + # Disable tracing for this call and all descendants + with tracing_disabled(): + res = await func(*args, **kwargs) + return res, call + + # Proceed with tracing (the sample rate is only applied to root calls) + try: + call = _create_call(op, *args, __weave=__weave, **kwargs) + except OpCallError as e: + raise e + except Exception as e: + if get_raise_on_captured_errors(): + raise + log_once( + logger.error, + ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), + ) res = await func(*args, **kwargs) else: - try: - # This try/except allows us to fail gracefully and - # still let the user code continue to execute - call = _create_call(op, *args, __weave=__weave, **kwargs) - except OpCallError as e: - raise e - except Exception as e: - if get_raise_on_captured_errors(): - raise - log_once( - logger.error, - 
ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), - ) - res = await func(*args, **kwargs) - else: - execute_result = _execute_op( - op, call, *args, __should_raise=__should_raise, **kwargs - ) - if not inspect.iscoroutine(execute_result): - raise TypeError( - "Internal error: Expected `_execute_call` to return a coroutine" - ) - execute_result = cast( - Coroutine[Any, Any, tuple[Any, "Call"]], execute_result + execute_result = _execute_op( + op, call, *args, __should_raise=__should_raise, **kwargs + ) + if not inspect.iscoroutine(execute_result): + raise TypeError( + "Internal error: Expected `_execute_call` to return a coroutine" ) - res, call = await execute_result + res, call = await execute_result return res, call @@ -540,6 +577,7 @@ def op( call_display_name: str | CallDisplayNameFunc | None = None, postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, + tracing_sample_rate: float = 1.0, ) -> Callable[[Callable], Op] | Op: """ A decorator to weave op-ify a function or method. Works for both sync and async. @@ -565,6 +603,7 @@ def op( postprocess_output (Optional[Callable[..., Any]]): A function to process the output after it's been returned from the function but before it's logged. This does not affect the actual output of the function, only the displayed output. + tracing_sample_rate (float): The sampling rate for tracing this function. Defaults to 1.0 (always trace). Returns: Union[Callable[[Any], Op], Op]: If called without arguments, returns a decorator. 
@@ -591,6 +630,10 @@ async def extract(): await extract() # calls the function and tracks the call in the Weave UI ``` """ + if not isinstance(tracing_sample_rate, (int, float)): + raise TypeError("tracing_sample_rate must be a float") + if not 0 <= tracing_sample_rate <= 1: + raise ValueError("tracing_sample_rate must be between 0 and 1") def op_deco(func: Callable) -> Op: # Check function type @@ -647,6 +690,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: wrapper._on_finish_handler = None # type: ignore wrapper._tracing_enabled = True # type: ignore + wrapper.tracing_sample_rate = tracing_sample_rate # type: ignore wrapper.get_captured_code = partial(get_captured_code, wrapper) # type: ignore