-
Notifications
You must be signed in to change notification settings - Fork 67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(weave): Add 'tracing_sample_rate' param to weave.op #3195
Changes from 4 commits
0830889
05bc825
8880ee2
ce39f49
cf553a2
2294620
7ce644f
0cd0fc2
a720f05
e5dc643
df7db03
ff01b4b
70b2c26
dfffb78
0a406a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -269,7 +269,7 @@ def adder(a: Number) -> Number: | |
|
||
adder_v0 = adder | ||
|
||
@weave.op() | ||
@weave.op() # type: ignore | ||
def adder(a: Number, b) -> Number: | ||
return Number(value=a.value + b) | ||
|
||
|
@@ -2127,7 +2127,7 @@ def calculate(a: int, b: int) -> int: | |
|
||
def test_call_query_stream_columns(client): | ||
@weave.op | ||
def calculate(a: int, b: int) -> int: | ||
def calculate(a: int, b: int) -> dict[str, Any]: | ||
return {"result": {"a + b": a + b}, "not result": 123} | ||
|
||
for i in range(2): | ||
|
@@ -2170,7 +2170,7 @@ def test_call_query_stream_columns_with_costs(client): | |
return | ||
|
||
@weave.op | ||
def calculate(a: int, b: int) -> int: | ||
def calculate(a: int, b: int) -> dict[str, Any]: | ||
return { | ||
"result": {"a + b": a + b}, | ||
"not result": 123, | ||
|
@@ -2272,7 +2272,8 @@ def test_obj(val): | |
|
||
# Ref at A, B and C | ||
test_op( | ||
values[7], {"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})} | ||
values[7], | ||
{"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})}, | ||
) | ||
|
||
for first, last, sort_by in [ | ||
|
@@ -2997,3 +2998,192 @@ def foo(): | |
foo() | ||
assert len(list(weave_client.get_calls())) == 1 | ||
assert weave.trace.weave_init._current_inited_client is None | ||
|
||
|
||
def test_op_sampling(client): | ||
never_traced_calls = 0 | ||
always_traced_calls = 0 | ||
sometimes_traced_calls = 0 | ||
|
||
@weave.op(tracing_sample_rate=0.0) | ||
def never_traced(x: int) -> int: | ||
nonlocal never_traced_calls | ||
never_traced_calls += 1 | ||
return x + 1 | ||
|
||
@weave.op(tracing_sample_rate=1.0) | ||
def always_traced(x: int) -> int: | ||
nonlocal always_traced_calls | ||
always_traced_calls += 1 | ||
return x + 1 | ||
|
||
@weave.op(tracing_sample_rate=0.5) | ||
def sometimes_traced(x: int) -> int: | ||
nonlocal sometimes_traced_calls | ||
sometimes_traced_calls += 1 | ||
return x + 1 | ||
|
||
# Never traced should execute but not be traced | ||
for i in range(10): | ||
never_traced(i) | ||
assert never_traced_calls == 10 # Function was called | ||
# NOTE: We can't assert here that never_traced.calls() is empty because that call requires | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you could just publish the op There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you mean just in this test? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. well it would be nice to have a consistent assertion for # of calls. so anywhere you did not to .calls could have explicit publishing |
||
# the op to be published. If we never trace, we never publish the op. | ||
assert "call_start" not in client.server.attribute_access_log | ||
|
||
# Always traced should execute and be traced | ||
for i in range(10): | ||
always_traced(i) | ||
assert always_traced_calls == 10 # Function was called | ||
assert len(list(always_traced.calls())) == 10 # And traced | ||
# Sanity check that the call_start was logged, unlike in the never_traced case. | ||
assert "call_start" in client.server.attribute_access_log | ||
|
||
# Sometimes traced should execute always but only be traced sometimes | ||
num_runs = 100 | ||
for i in range(num_runs): | ||
sometimes_traced(i) | ||
assert sometimes_traced_calls == num_runs # Function was called every time | ||
num_traces = len(list(sometimes_traced.calls())) | ||
assert 40 < num_traces < 60 # But only traced ~50% of the time | ||
|
||
|
||
def test_op_sampling_async(client): | ||
never_traced_calls = 0 | ||
always_traced_calls = 0 | ||
sometimes_traced_calls = 0 | ||
|
||
@weave.op(tracing_sample_rate=0.0) | ||
async def never_traced(x: int) -> int: | ||
nonlocal never_traced_calls | ||
never_traced_calls += 1 | ||
return x + 1 | ||
|
||
@weave.op(tracing_sample_rate=1.0) | ||
async def always_traced(x: int) -> int: | ||
nonlocal always_traced_calls | ||
always_traced_calls += 1 | ||
return x + 1 | ||
|
||
@weave.op(tracing_sample_rate=0.5) | ||
async def sometimes_traced(x: int) -> int: | ||
nonlocal sometimes_traced_calls | ||
sometimes_traced_calls += 1 | ||
return x + 1 | ||
|
||
import asyncio | ||
|
||
# Never traced should execute but not be traced | ||
for i in range(10): | ||
asyncio.run(never_traced(i)) | ||
assert never_traced_calls == 10 # Function was called | ||
assert "call_start" not in client.server.attribute_access_log | ||
|
||
# Always traced should execute and be traced | ||
for i in range(10): | ||
asyncio.run(always_traced(i)) | ||
assert always_traced_calls == 10 # Function was called | ||
assert len(list(always_traced.calls())) == 10 # And traced | ||
assert "call_start" in client.server.attribute_access_log | ||
|
||
# Sometimes traced should execute always but only be traced sometimes | ||
num_runs = 100 | ||
for i in range(num_runs): | ||
asyncio.run(sometimes_traced(i)) | ||
assert sometimes_traced_calls == num_runs # Function was called every time | ||
num_traces = len(list(sometimes_traced.calls())) | ||
assert 40 < num_traces < 60 # But only traced ~50% of the time | ||
|
||
|
||
def test_op_sampling_inheritance(client): | ||
parent_calls = 0 | ||
child_calls = 0 | ||
|
||
@weave.op() | ||
def child_op(x: int) -> int: | ||
nonlocal child_calls | ||
child_calls += 1 | ||
return x + 1 | ||
|
||
@weave.op(tracing_sample_rate=0.0) | ||
def parent_op(x: int) -> int: | ||
nonlocal parent_calls | ||
parent_calls += 1 | ||
return child_op(x) | ||
|
||
# When parent is sampled out, child should still execute but not be traced | ||
for i in range(10): | ||
parent_op(i) | ||
|
||
assert parent_calls == 10 # Parent function executed | ||
assert child_calls == 10 # Child function executed | ||
assert "call_start" not in client.server.attribute_access_log | ||
|
||
# Reset counters | ||
child_calls = 0 | ||
|
||
# Direct calls to child should execute and be traced | ||
for i in range(10): | ||
child_op(i) | ||
|
||
assert child_calls == 10 # Child function executed | ||
assert len(list(child_op.calls())) == 10 # And was traced | ||
assert "call_start" in client.server.attribute_access_log # Verify tracing occurred | ||
|
||
|
||
def test_op_sampling_inheritance_async(client): | ||
parent_calls = 0 | ||
child_calls = 0 | ||
|
||
@weave.op() | ||
async def child_op(x: int) -> int: | ||
nonlocal child_calls | ||
child_calls += 1 | ||
return x + 1 | ||
|
||
@weave.op(tracing_sample_rate=0.0) | ||
async def parent_op(x: int) -> int: | ||
nonlocal parent_calls | ||
parent_calls += 1 | ||
return await child_op(x) | ||
|
||
import asyncio | ||
|
||
# When parent is sampled out, child should still execute but not be traced | ||
for i in range(10): | ||
asyncio.run(parent_op(i)) | ||
|
||
assert parent_calls == 10 # Parent function executed | ||
assert child_calls == 10 # Child function executed | ||
assert "call_start" not in client.server.attribute_access_log | ||
|
||
# Reset counters | ||
child_calls = 0 | ||
|
||
# Direct calls to child should execute and be traced | ||
for i in range(10): | ||
asyncio.run(child_op(i)) | ||
|
||
assert child_calls == 10 # Child function executed | ||
assert len(list(child_op.calls())) == 10 # And was traced | ||
assert "call_start" in client.server.attribute_access_log # Verify tracing occurred | ||
|
||
|
||
def test_op_sampling_invalid_rates(client): | ||
with pytest.raises(ValueError): | ||
|
||
@weave.op(tracing_sample_rate=-0.5) | ||
def negative_rate(): | ||
pass | ||
|
||
with pytest.raises(ValueError): | ||
|
||
@weave.op(tracing_sample_rate=1.5) | ||
def too_high_rate(): | ||
pass | ||
|
||
with pytest.raises(ValueError): | ||
|
||
@weave.op(tracing_sample_rate="invalid") # type: ignore | ||
def invalid_type(): | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't use paren