Skip to content

Commit

Permalink
Make tests simpler and update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
adrnswanberg committed Dec 13, 2024
1 parent 70b2c26 commit dfffb78
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 7 deletions.
2 changes: 2 additions & 0 deletions docs/docs/guides/tracking/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ A Weave op is a versioned function that automatically logs all calls.
<TabItem value="python" label="Python" default>
You can control how frequently an op's calls are traced by setting the `tracing_sample_rate` parameter in the `@weave.op` decorator. This is useful for high-frequency ops where you only need to trace a subset of calls.

Note that sampling rates are only applied to root calls. If an op has a sample rate, but is called by another op first, then that sampling rate will be ignored.

```python
@weave.op(tracing_sample_rate=0.1) # Only trace ~10% of calls
def high_frequency_op(x: int) -> int:
Expand Down
44 changes: 38 additions & 6 deletions tests/trace/test_client_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3023,13 +3023,12 @@ def sometimes_traced(x: int) -> int:
sometimes_traced_calls += 1
return x + 1

weave.publish(never_traced)
# Never traced should execute but not be traced
for i in range(10):
never_traced(i)
assert never_traced_calls == 10 # Function was called
# NOTE: We can't assert here that never_traced.calls() is empty because that call requires
# the op to be published. If we never trace, we never publish the op.
assert "call_start" not in client.server.attribute_access_log
assert len(list(never_traced.calls())) == 0 # Not traced

# Always traced should execute and be traced
for i in range(10):
Expand Down Expand Up @@ -3073,11 +3072,12 @@ async def sometimes_traced(x: int) -> int:

import asyncio

weave.publish(never_traced)
# Never traced should execute but not be traced
for i in range(10):
asyncio.run(never_traced(i))
assert never_traced_calls == 10 # Function was called
assert "call_start" not in client.server.attribute_access_log
assert len(list(never_traced.calls())) == 0 # Not traced

# Always traced should execute and be traced
for i in range(10):
Expand Down Expand Up @@ -3111,13 +3111,14 @@ def parent_op(x: int) -> int:
parent_calls += 1
return child_op(x)

weave.publish(parent_op)
# When parent is sampled out, child should still execute but not be traced
for i in range(10):
parent_op(i)

assert parent_calls == 10 # Parent function executed
assert child_calls == 10 # Child function executed
assert "call_start" not in client.server.attribute_access_log
assert len(list(parent_op.calls())) == 0 # Parent not traced

# Reset counters
child_calls = 0
Expand Down Expand Up @@ -3149,13 +3150,14 @@ async def parent_op(x: int) -> int:

import asyncio

weave.publish(parent_op)
# When parent is sampled out, child should still execute but not be traced
for i in range(10):
asyncio.run(parent_op(i))

assert parent_calls == 10 # Parent function executed
assert child_calls == 10 # Child function executed
assert "call_start" not in client.server.attribute_access_log
assert len(list(parent_op.calls())) == 0 # Parent not traced

# Reset counters
child_calls = 0
Expand Down Expand Up @@ -3187,3 +3189,33 @@ def too_high_rate():
@weave.op(tracing_sample_rate="invalid") # type: ignore
def invalid_type():
pass


def test_op_sampling_child_follows_parent(client):
parent_calls = 0
child_calls = 0

@weave.op(tracing_sample_rate=0.0) # Never traced
def child_op(x: int) -> int:
nonlocal child_calls
child_calls += 1
return x + 1

@weave.op(tracing_sample_rate=1.0) # Always traced
def parent_op(x: int) -> int:
nonlocal parent_calls
parent_calls += 1
return child_op(x)

num_runs = 100
for i in range(num_runs):
parent_op(i)

assert parent_calls == num_runs # Parent was always executed
assert child_calls == num_runs # Child was always executed

parent_traces = len(list(parent_op.calls()))
child_traces = len(list(child_op.calls()))

assert parent_traces == num_runs # Parent was always traced
assert child_traces == num_runs # Child was traced whenever parent was
4 changes: 3 additions & 1 deletion weave/trace/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,9 @@ def _do_call(
res = func(*pargs.args, **pargs.kwargs)
return res, call

# Proceed with tracing
# Proceed with tracing. Note that we don't check the sample rate here.
# Only root calls get sampling applied.
# If the parent was traced (sampled in), the child will be too.
try:
call = _create_call(op, *args, __weave=__weave, **kwargs)
except OpCallError as e:
Expand Down

0 comments on commit dfffb78

Please sign in to comment.