From dfffb7848d9941394db14b764ee1d04fee681c25 Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Thu, 12 Dec 2024 21:40:16 -0800 Subject: [PATCH] Make tests simpler and update docs --- docs/docs/guides/tracking/ops.md | 2 ++ tests/trace/test_client_trace.py | 44 +++++++++++++++++++++++++++----- weave/trace/op.py | 4 ++- 3 files changed, 43 insertions(+), 7 deletions(-) diff --git a/docs/docs/guides/tracking/ops.md b/docs/docs/guides/tracking/ops.md index 1fb259d1b37..4c1e064b0aa 100644 --- a/docs/docs/guides/tracking/ops.md +++ b/docs/docs/guides/tracking/ops.md @@ -122,6 +122,8 @@ A Weave op is a versioned function that automatically logs all calls. You can control how frequently an op's calls are traced by setting the `tracing_sample_rate` parameter in the `@weave.op` decorator. This is useful for high-frequency ops where you only need to trace a subset of calls. + Note that sampling rates are only applied to root calls. If an op has a sample rate, but is called by another op first, then that sampling rate will be ignored. + ```python @weave.op(tracing_sample_rate=0.1) # Only trace ~10% of calls def high_frequency_op(x: int) -> int: diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index 9d1fb4bb223..005c79f5cb0 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -3023,13 +3023,12 @@ def sometimes_traced(x: int) -> int: sometimes_traced_calls += 1 return x + 1 + weave.publish(never_traced) # Never traced should execute but not be traced for i in range(10): never_traced(i) assert never_traced_calls == 10 # Function was called - # NOTE: We can't assert here that never_traced.calls() is empty because that call requires - # the op to be published. If we never trace, we never publish the op. - assert "call_start" not in client.server.attribute_access_log + assert len(list(never_traced.calls())) == 0 # Not traced # Always traced should execute and be traced for i in range(10): @@ -3073,11 +3072,12 @@ async def sometimes_traced(x: int) -> int: import asyncio + weave.publish(never_traced) # Never traced should execute but not be traced for i in range(10): asyncio.run(never_traced(i)) assert never_traced_calls == 10 # Function was called - assert "call_start" not in client.server.attribute_access_log + assert len(list(never_traced.calls())) == 0 # Not traced # Always traced should execute and be traced for i in range(10): @@ -3111,13 +3111,14 @@ def parent_op(x: int) -> int: parent_calls += 1 return child_op(x) + weave.publish(parent_op) # When parent is sampled out, child should still execute but not be traced for i in range(10): parent_op(i) assert parent_calls == 10 # Parent function executed assert child_calls == 10 # Child function executed - assert "call_start" not in client.server.attribute_access_log + assert len(list(parent_op.calls())) == 0 # Parent not traced # Reset counters child_calls = 0 @@ -3149,13 +3150,14 @@ async def parent_op(x: int) -> int: import asyncio + weave.publish(parent_op) # When parent is sampled out, child should still execute but not be traced for i in range(10): asyncio.run(parent_op(i)) assert parent_calls == 10 # Parent function executed assert child_calls == 10 # Child function executed - assert "call_start" not in client.server.attribute_access_log + assert len(list(parent_op.calls())) == 0 # Parent not traced # Reset counters child_calls = 0 @@ -3187,3 +3189,33 @@ def too_high_rate(): @weave.op(tracing_sample_rate="invalid") # type: ignore def invalid_type(): pass + + +def test_op_sampling_child_follows_parent(client): + parent_calls = 0 + child_calls = 0 + + @weave.op(tracing_sample_rate=0.0) # Never traced + def child_op(x: int) -> int: + nonlocal child_calls + child_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=1.0) # Always traced + def parent_op(x: int) -> int: + nonlocal parent_calls + parent_calls += 1 + return child_op(x) + + num_runs = 100 + for i in range(num_runs): + parent_op(i) + + assert parent_calls == num_runs # Parent was always executed + assert child_calls == num_runs # Child was always executed + + parent_traces = len(list(parent_op.calls())) + child_traces = len(list(child_op.calls())) + + assert parent_traces == num_runs # Parent was always traced + assert child_traces == num_runs # Child was traced whenever parent was diff --git a/weave/trace/op.py b/weave/trace/op.py index 37a6dc1c160..a89c7400d8b 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -437,7 +437,9 @@ def _do_call( res = func(*pargs.args, **pargs.kwargs) return res, call - # Proceed with tracing + # Proceed with tracing. Note that we don't check the sample rate here. + # Only root calls get sampling applied. + # If the parent was traced (sampled in), the child will be too. try: call = _create_call(op, *args, __weave=__weave, **kwargs) except OpCallError as e: