Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(weave): Add 'tracing_sample_rate' param to weave.op #3195

Merged
merged 15 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 194 additions & 4 deletions tests/trace/test_client_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def adder(a: Number) -> Number:

adder_v0 = adder

@weave.op()
@weave.op() # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't use paren

def adder(a: Number, b) -> Number:
return Number(value=a.value + b)

Expand Down Expand Up @@ -2127,7 +2127,7 @@ def calculate(a: int, b: int) -> int:

def test_call_query_stream_columns(client):
@weave.op
def calculate(a: int, b: int) -> int:
def calculate(a: int, b: int) -> dict[str, Any]:
return {"result": {"a + b": a + b}, "not result": 123}

for i in range(2):
Expand Down Expand Up @@ -2170,7 +2170,7 @@ def test_call_query_stream_columns_with_costs(client):
return

@weave.op
def calculate(a: int, b: int) -> int:
def calculate(a: int, b: int) -> dict[str, Any]:
return {
"result": {"a + b": a + b},
"not result": 123,
Expand Down Expand Up @@ -2272,7 +2272,8 @@ def test_obj(val):

# Ref at A, B and C
test_op(
values[7], {"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})}
values[7],
{"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})},
)

for first, last, sort_by in [
Expand Down Expand Up @@ -2997,3 +2998,192 @@ def foo():
foo()
assert len(list(weave_client.get_calls())) == 1
assert weave.trace.weave_init._current_inited_client is None


def test_op_sampling(client):
never_traced_calls = 0
always_traced_calls = 0
sometimes_traced_calls = 0

@weave.op(tracing_sample_rate=0.0)
def never_traced(x: int) -> int:
nonlocal never_traced_calls
never_traced_calls += 1
return x + 1

@weave.op(tracing_sample_rate=1.0)
def always_traced(x: int) -> int:
nonlocal always_traced_calls
always_traced_calls += 1
return x + 1

@weave.op(tracing_sample_rate=0.5)
def sometimes_traced(x: int) -> int:
nonlocal sometimes_traced_calls
sometimes_traced_calls += 1
return x + 1

# Never traced should execute but not be traced
for i in range(10):
never_traced(i)
assert never_traced_calls == 10 # Function was called
# NOTE: We can't assert here that never_traced.calls() is empty because that call requires
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could just publish the op

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you mean just in this test?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it would be nice to have a consistent assertion for the number of calls, so anywhere you could not use `.calls()` could explicitly publish the op instead.

# the op to be published. If we never trace, we never publish the op.
assert "call_start" not in client.server.attribute_access_log

# Always traced should execute and be traced
for i in range(10):
always_traced(i)
assert always_traced_calls == 10 # Function was called
assert len(list(always_traced.calls())) == 10 # And traced
# Sanity check that the call_start was logged, unlike in the never_traced case.
assert "call_start" in client.server.attribute_access_log

# Sometimes traced should execute always but only be traced sometimes
num_runs = 100
for i in range(num_runs):
sometimes_traced(i)
assert sometimes_traced_calls == num_runs # Function was called every time
num_traces = len(list(sometimes_traced.calls()))
assert 40 < num_traces < 60 # But only traced ~50% of the time


def test_op_sampling_async(client):
    """Check `tracing_sample_rate` on async ops at rates 0.0, 1.0 and 0.5.

    In every case the wrapped coroutine must run on each invocation; only
    whether a trace is recorded is expected to depend on the sample rate.
    """
    import asyncio

    never_traced_calls = 0
    always_traced_calls = 0
    sometimes_traced_calls = 0

    @weave.op(tracing_sample_rate=0.0)
    async def never_traced(x: int) -> int:
        nonlocal never_traced_calls
        never_traced_calls += 1
        return x + 1

    @weave.op(tracing_sample_rate=1.0)
    async def always_traced(x: int) -> int:
        nonlocal always_traced_calls
        always_traced_calls += 1
        return x + 1

    @weave.op(tracing_sample_rate=0.5)
    async def sometimes_traced(x: int) -> int:
        nonlocal sometimes_traced_calls
        sometimes_traced_calls += 1
        return x + 1

    # Never traced should execute but not be traced.
    # This check must run before any traced op is invoked, because it looks
    # at the server-wide access log rather than at this op's own calls.
    for i in range(10):
        asyncio.run(never_traced(i))
    assert never_traced_calls == 10  # Function was called
    assert "call_start" not in client.server.attribute_access_log

    # Always traced should execute and be traced
    for i in range(10):
        asyncio.run(always_traced(i))
    assert always_traced_calls == 10  # Function was called
    assert len(list(always_traced.calls())) == 10  # And traced
    # Sanity check that call_start was logged, unlike in the never_traced case.
    assert "call_start" in client.server.attribute_access_log

    # Sometimes traced should execute always but only be traced sometimes
    num_runs = 100
    for i in range(num_runs):
        asyncio.run(sometimes_traced(i))
    assert sometimes_traced_calls == num_runs  # Function was called every time
    num_traces = len(list(sometimes_traced.calls()))
    # With 100 independent 50% draws the standard deviation is 5, so a
    # 40..60 window (2 sigma) fails by chance in roughly 1 run out of 20.
    # A 5-sigma window keeps the test deterministic in practice while still
    # ruling out "never traced" (0) and "always traced" (100).
    assert 25 < num_traces < 75  # But only traced ~50% of the time


def test_op_sampling_inheritance(client):
    """A child op called from a sampled-out parent runs but is not traced.

    Afterwards the same child op, called directly, is expected to be traced
    on every call.
    """
    runs = 10
    parent_calls = 0
    child_calls = 0

    @weave.op()
    def child_op(x: int) -> int:
        nonlocal child_calls
        child_calls += 1
        return x + 1

    @weave.op(tracing_sample_rate=0.0)
    def parent_op(x: int) -> int:
        nonlocal parent_calls
        parent_calls += 1
        return child_op(x)

    # Phase 1: go through the parent, whose sample rate is 0.0.
    for value in range(runs):
        parent_op(value)

    # Both function bodies ran each time ...
    assert parent_calls == runs
    assert child_calls == runs
    # ... but no call_start shows up in the server's access log.
    assert "call_start" not in client.server.attribute_access_log

    # Phase 2: call the child on its own, starting from a fresh child counter.
    child_calls = 0
    for value in range(runs):
        child_op(value)

    assert child_calls == runs  # Child function executed
    assert len(list(child_op.calls())) == runs  # And every call was traced
    assert "call_start" in client.server.attribute_access_log  # Tracing occurred


def test_op_sampling_inheritance_async(client):
    """Async variant: a child op awaited by a sampled-out parent is not traced.

    Afterwards the same child coroutine, run directly, is expected to be
    traced on every call.
    """
    import asyncio

    runs = 10
    parent_calls = 0
    child_calls = 0

    @weave.op()
    async def child_op(x: int) -> int:
        nonlocal child_calls
        child_calls += 1
        return x + 1

    @weave.op(tracing_sample_rate=0.0)
    async def parent_op(x: int) -> int:
        nonlocal parent_calls
        parent_calls += 1
        return await child_op(x)

    # Phase 1: go through the parent, whose sample rate is 0.0.
    for value in range(runs):
        asyncio.run(parent_op(value))

    # Both coroutine bodies ran each time ...
    assert parent_calls == runs
    assert child_calls == runs
    # ... but no call_start shows up in the server's access log.
    assert "call_start" not in client.server.attribute_access_log

    # Phase 2: run the child on its own, starting from a fresh child counter.
    child_calls = 0
    for value in range(runs):
        asyncio.run(child_op(value))

    assert child_calls == runs  # Child function executed
    assert len(list(child_op.calls())) == runs  # And every call was traced
    assert "call_start" in client.server.attribute_access_log  # Tracing occurred


def test_op_sampling_invalid_rates(client):
    """Decorating with an invalid `tracing_sample_rate` must raise ValueError.

    The rates tried are a negative number, a number above 1, and a string.
    """
    for bad_rate in (-0.5, 1.5, "invalid"):
        with pytest.raises(ValueError):

            @weave.op(tracing_sample_rate=bad_rate)  # type: ignore
            def op_with_bad_rate():
                pass
21 changes: 21 additions & 0 deletions weave/trace/context/call_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class NoCurrentCallError(Exception): ...

logger = logging.getLogger(__name__)

# Context-local flag read by get_tracing_enabled(); it is True unless a
# context has explicitly set it otherwise.
_tracing_enabled: contextvars.ContextVar[bool] = contextvars.ContextVar("tracing_enabled", default=True)


def push_call(call: Call) -> None:
new_stack = copy.copy(_call_stack.get())
Expand Down Expand Up @@ -136,3 +138,22 @@ def set_call_stack(stack: list[Call]) -> Iterator[list[Call]]:
# Context-local dict of call attributes.
# NOTE(review): the default `{}` is a single object handed back by `.get()`
# in every context that has not set the variable, so mutating that returned
# dict in place would leak across contexts — confirm callers always `.set()`
# a new dict instead.
call_attributes: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar(
    "call_attributes", default={}
)


def get_tracing_enabled() -> bool:
    """Return the value of the `_tracing_enabled` context variable for the current context."""
    enabled = _tracing_enabled.get()
    return enabled


@contextlib.contextmanager
def set_tracing_enabled(enabled: bool) -> Iterator[None]:
    """Set the tracing-enabled flag to `enabled` for the duration of a `with` block.

    The previous value is restored on exit, including when the body raises.

    Args:
        enabled: Value to store in the `_tracing_enabled` context variable.
    """
    # Keep the token so the exact prior state can be put back afterwards.
    restore_token = _tracing_enabled.set(enabled)
    try:
        yield
    finally:
        _tracing_enabled.reset(restore_token)


@contextlib.contextmanager
def tracing_disabled() -> Iterator[None]:
    """Context manager shorthand for `set_tracing_enabled(False)`."""
    with set_tracing_enabled(enabled=False):
        yield
Loading
Loading