Skip to content

Commit

Permalink
feat(weave): Add 'tracing_sampling_rate' param to weave.op
Browse files Browse the repository at this point in the history
  • Loading branch information
adrnswanberg committed Dec 10, 2024
1 parent 0966c9f commit 0830889
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 28 deletions.
115 changes: 111 additions & 4 deletions tests/trace/test_client_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def adder(a: Number) -> Number:

adder_v0 = adder

@weave.op()
@weave.op() # type: ignore
def adder(a: Number, b) -> Number:
return Number(value=a.value + b)

Expand Down Expand Up @@ -2127,7 +2127,7 @@ def calculate(a: int, b: int) -> int:

def test_call_query_stream_columns(client):
@weave.op
def calculate(a: int, b: int) -> int:
def calculate(a: int, b: int) -> dict[str, Any]:
return {"result": {"a + b": a + b}, "not result": 123}

for i in range(2):
Expand Down Expand Up @@ -2170,7 +2170,7 @@ def test_call_query_stream_columns_with_costs(client):
return

@weave.op
def calculate(a: int, b: int) -> int:
def calculate(a: int, b: int) -> dict[str, Any]:
return {
"result": {"a + b": a + b},
"not result": 123,
Expand Down Expand Up @@ -2272,7 +2272,8 @@ def test_obj(val):

# Ref at A, B and C
test_op(
values[7], {"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})}
values[7],
{"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})},
)

for first, last, sort_by in [
Expand Down Expand Up @@ -2997,3 +2998,109 @@ def foo():
foo()
assert len(list(weave_client.get_calls())) == 1
assert weave.trace.weave_init._current_inited_client is None


def test_op_sampling(client):
never_traced_calls = 0
always_traced_calls = 0
sometimes_traced_calls = 0

@weave.op(tracing_sample_rate=0.0)
def never_traced(x: int) -> int:
nonlocal never_traced_calls
never_traced_calls += 1
return x + 1

@weave.op(tracing_sample_rate=1.0)
def always_traced(x: int) -> int:
nonlocal always_traced_calls
always_traced_calls += 1
return x + 1

@weave.op(tracing_sample_rate=0.5)
def sometimes_traced(x: int) -> int:
nonlocal sometimes_traced_calls
sometimes_traced_calls += 1
return x + 1

# Never traced should execute but not be traced
for i in range(10):
never_traced(i)
assert never_traced_calls == 10 # Function was called
# NOTE: We can't assert here that never_traced.calls() is empty because that call requires
# the op to be published. If we never trace, we never publish the op.
assert "call_start" not in client.server.attribute_access_log

# Always traced should execute and be traced
for i in range(10):
always_traced(i)
assert always_traced_calls == 10 # Function was called
assert len(list(always_traced.calls())) == 10 # And traced
# Sanity check that the call_start was logged, unlike in the never_traced case.
assert "call_start" in client.server.attribute_access_log

# Sometimes traced should execute always but only be traced sometimes
num_runs = 100
for i in range(num_runs):
sometimes_traced(i)
assert sometimes_traced_calls == num_runs # Function was called every time
num_traces = len(list(sometimes_traced.calls()))
assert 40 < num_traces < 60 # But only traced ~50% of the time


def test_op_sampling_inheritance(client):
parent_calls = 0
child_calls = 0

@weave.op()
def child_op(x: int) -> int:
nonlocal child_calls
child_calls += 1
return x + 1

@weave.op(tracing_sample_rate=0.0)
def parent_op(x: int) -> int:
nonlocal parent_calls
parent_calls += 1
return child_op(x)

# When parent is sampled out, child should still execute but not be traced
for i in range(10):
parent_op(i)

assert parent_calls == 10 # Parent function executed
assert child_calls == 10 # Child function executed
assert (
"call_start" not in client.server.attribute_access_log
) # But neither was traced

# Reset counters
child_calls = 0

# Direct calls to child should execute and be traced
for i in range(10):
child_op(i)

assert child_calls == 10 # Child function executed
assert len(list(child_op.calls())) == 10 # And was traced
assert "call_start" in client.server.attribute_access_log # Verify tracing occurred


def test_op_sampling_invalid_rates(client):
with pytest.raises(ValueError):

@weave.op(tracing_sample_rate=-0.5)
def negative_rate():
pass

with pytest.raises(ValueError):

@weave.op(tracing_sample_rate=1.5)
def too_high_rate():
pass

with pytest.raises(ValueError):

@weave.op(tracing_sample_rate="invalid") # type: ignore
def invalid_type():
pass
21 changes: 21 additions & 0 deletions weave/trace/context/call_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class NoCurrentCallError(Exception): ...

logger = logging.getLogger(__name__)

_tracing_enabled = contextvars.ContextVar("tracing_enabled", default=True)


def push_call(call: Call) -> None:
new_stack = copy.copy(_call_stack.get())
Expand Down Expand Up @@ -136,3 +138,22 @@ def set_call_stack(stack: list[Call]) -> Iterator[list[Call]]:
call_attributes: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar(
"call_attributes", default={}
)


def get_tracing_enabled() -> bool:
return _tracing_enabled.get()


@contextlib.contextmanager
def set_tracing_enabled(enabled: bool) -> Iterator[None]:
token = _tracing_enabled.set(enabled)
try:
yield
finally:
_tracing_enabled.reset(token)


@contextlib.contextmanager
def tracing_disabled() -> Iterator[None]:
with set_tracing_enabled(False):
yield
83 changes: 59 additions & 24 deletions weave/trace/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import inspect
import logging
import random
import sys
import traceback
from collections.abc import Coroutine, Mapping
Expand All @@ -26,7 +27,11 @@
from weave.trace.constants import TRACE_CALL_EMOJI
from weave.trace.context import call_context
from weave.trace.context import weave_client_context as weave_client_context
from weave.trace.context.call_context import call_attributes
from weave.trace.context.call_context import (
call_attributes,
get_tracing_enabled,
tracing_disabled,
)
from weave.trace.context.tests_context import get_raise_on_captured_errors
from weave.trace.errors import OpCallError
from weave.trace.refs import ObjectRef
Expand Down Expand Up @@ -174,6 +179,8 @@ class Op(Protocol):
# it disables child ops as well.
_tracing_enabled: bool

tracing_sample_rate: float


def _set_on_input_handler(func: Op, on_input: OnInputHandlerType) -> None:
if func._on_input_handler is not None:
Expand Down Expand Up @@ -401,22 +408,33 @@ def _do_call(
func = op.resolve_fn
call = _placeholder_call()

pargs = None
if op._on_input_handler is not None:
pargs = op._on_input_handler(op, args, kwargs)
if not pargs:
pargs = _default_on_input_handler(op, args, kwargs)
pargs = (
op._on_input_handler(op, args, kwargs)
if op._on_input_handler
else _default_on_input_handler(op, args, kwargs)
)

if settings.should_disable_weave():
res = func(*pargs.args, **pargs.kwargs)
elif weave_client_context.get_weave_client() is None:
res = func(*pargs.args, **pargs.kwargs)
elif not op._tracing_enabled:
skip_tracing = (
settings.should_disable_weave()
or weave_client_context.get_weave_client() is None
or not op._tracing_enabled
or not get_tracing_enabled()
)

if skip_tracing:
res = func(*pargs.args, **pargs.kwargs)
else:
current_call = call_context.get_current_call()
if current_call is None:
# Root call: decide whether to trace based on sample rate
if random.random() > op.tracing_sample_rate:
# Disable tracing for this call and all descendants
with tracing_disabled():
res = func(*pargs.args, **pargs.kwargs)
return res, call

# Proceed with tracing
try:
# This try/except allows us to fail gracefully and
# still let the user code continue to execute
call = _create_call(op, *args, __weave=__weave, **kwargs)
except OpCallError as e:
raise e
Expand All @@ -436,7 +454,6 @@ def _do_call(
raise TypeError(
"Internal error: Expected `_execute_call` to return a sync result"
)
execute_result = cast(tuple[Any, "Call"], execute_result)
res, call = execute_result
return res, call

Expand All @@ -450,16 +467,30 @@ async def _do_call_async(
) -> tuple[Any, Call]:
func = op.resolve_fn
call = _placeholder_call()
if settings.should_disable_weave():
res = await func(*args, **kwargs)
elif weave_client_context.get_weave_client() is None:
res = await func(*args, **kwargs)
elif not op._tracing_enabled:

if (
settings.should_disable_weave()
or weave_client_context.get_weave_client() is None
or not op._tracing_enabled
):
res = await func(*args, **kwargs)
else:
current_call = call_context.get_current_call()
tracing_enabled = get_tracing_enabled()
if current_call is None:
# Root call: decide whether to trace based on sample rate
if random.random() > op.tracing_sample_rate:
# Disable tracing for this call and all descendants
with tracing_disabled():
res = await func(*args, **kwargs)
return res, call
elif not tracing_enabled:
# Tracing is disabled in the context
res = await func(*args, **kwargs)
return res, call

# Proceed with tracing
try:
# This try/except allows us to fail gracefully and
# still let the user code continue to execute
call = _create_call(op, *args, __weave=__weave, **kwargs)
except OpCallError as e:
raise e
Expand All @@ -479,9 +510,6 @@ async def _do_call_async(
raise TypeError(
"Internal error: Expected `_execute_call` to return a coroutine"
)
execute_result = cast(
Coroutine[Any, Any, tuple[Any, "Call"]], execute_result
)
res, call = await execute_result
return res, call

Expand Down Expand Up @@ -540,6 +568,7 @@ def op(
call_display_name: str | CallDisplayNameFunc | None = None,
postprocess_inputs: PostprocessInputsFunc | None = None,
postprocess_output: PostprocessOutputFunc | None = None,
tracing_sample_rate: float = 1.0,
) -> Callable[[Callable], Op] | Op:
"""
A decorator to weave op-ify a function or method. Works for both sync and async.
Expand All @@ -565,6 +594,7 @@ def op(
postprocess_output (Optional[Callable[..., Any]]): A function to process the output
after it's been returned from the function but before it's logged. This does not
affect the actual output of the function, only the displayed output.
tracing_sample_rate (float): The sampling rate for tracing this function. Defaults to 1.0 (always trace).
Returns:
Union[Callable[[Any], Op], Op]: If called without arguments, returns a decorator.
Expand All @@ -591,6 +621,10 @@ async def extract():
await extract() # calls the function and tracks the call in the Weave UI
```
"""
if not isinstance(tracing_sample_rate, (int, float)):
raise ValueError("tracing_sample_rate must be a float")
if not 0 <= tracing_sample_rate <= 1:
raise ValueError("tracing_sample_rate must be between 0 and 1")

def op_deco(func: Callable) -> Op:
# Check function type
Expand Down Expand Up @@ -647,6 +681,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
wrapper._on_finish_handler = None # type: ignore

wrapper._tracing_enabled = True # type: ignore
wrapper.tracing_sample_rate = tracing_sample_rate # type: ignore

wrapper.get_captured_code = partial(get_captured_code, wrapper) # type: ignore

Expand Down

0 comments on commit 0830889

Please sign in to comment.