diff --git a/docs/docs/guides/tracking/ops.md b/docs/docs/guides/tracking/ops.md index b69d5d1d91a..4c1e064b0aa 100644 --- a/docs/docs/guides/tracking/ops.md +++ b/docs/docs/guides/tracking/ops.md @@ -116,6 +116,39 @@ A Weave op is a versioned function that automatically logs all calls. +## Control sampling rate + + + + You can control how frequently an op's calls are traced by setting the `tracing_sample_rate` parameter in the `@weave.op` decorator. This is useful for high-frequency ops where you only need to trace a subset of calls. + + Note that sampling rates are only applied to root calls. If an op with a sample rate is called from inside another op, its own sample rate is ignored and it is traced only if the root call was traced. + + ```python + @weave.op(tracing_sample_rate=0.1) # Only trace ~10% of calls + def high_frequency_op(x: int) -> int: + return x + 1 + + @weave.op(tracing_sample_rate=1.0) # Always trace (default) + def always_traced_op(x: int) -> int: + return x + 1 + ``` + + When an op's call is not sampled: + - The function executes normally + - No trace data is sent to Weave + - Child ops are also not traced for that call + + The sampling rate must be between 0.0 and 1.0 inclusive. + + + + ```plaintext + This feature is not available in TypeScript yet. Stay tuned! + ``` + + + ### Control call link output If you want to suppress the printing of call links during logging, you can use the `WEAVE_PRINT_CALL_LINK` environment variable to `false`. This can be useful if you want to reduce output verbosity and reduce clutter in your logs. 
diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index 5b45abc8432..005c79f5cb0 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -60,7 +60,7 @@ def get_client_project_id(client: weave_client.WeaveClient) -> str: def test_simple_op(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -229,7 +229,7 @@ def test_call_read_not_found(client): def test_graph_call_ordering(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -263,27 +263,27 @@ def simple_line_call_bootstrap(init_wandb: bool = False) -> OpCallSpec: class Number(weave.Object): value: int - @weave.op() + @weave.op def adder(a: Number) -> Number: return Number(value=a.value + a.value) adder_v0 = adder - @weave.op() + @weave.op # type: ignore def adder(a: Number, b) -> Number: return Number(value=a.value + b) - @weave.op() + @weave.op def subtractor(a: Number, b) -> Number: return Number(value=a.value - b) - @weave.op() + @weave.op def multiplier( a: Number, b ) -> int: # intentionally deviant in returning plain int - so that we have a different type return a.value * b - @weave.op() + @weave.op def liner(m: Number, b, x) -> Number: return adder(Number(value=multiplier(m, x)), b) @@ -691,7 +691,7 @@ def test_trace_call_query_offset(client): def test_trace_call_sort(client): - @weave.op() + @weave.op def basic_op(in_val: dict, delay) -> dict: import time @@ -727,7 +727,7 @@ def test_trace_call_sort_with_mixed_types(client): # SQLite does not support sorting over mixed types in a column, so we skip this test return - @weave.op() + @weave.op def basic_op(in_val: dict) -> dict: import time @@ -769,7 +769,7 @@ def basic_op(in_val: dict) -> dict: def test_trace_call_filter(client): is_sqlite = client_is_sqlite(client) - @weave.op() + @weave.op def basic_op(in_val: dict, delay) -> dict: return in_val @@ -1160,7 +1160,7 @@ def basic_op(in_val: dict, delay) -> dict: def 
test_ops_with_default_params(client): - @weave.op() + @weave.op def op_with_default(a: int, b: int = 10) -> int: return a + b @@ -1234,7 +1234,7 @@ class BaseTypeC(BaseTypeB): def test_attributes_on_ops(client): - @weave.op() + @weave.op def op_with_attrs(a: int, b: int) -> int: return a + b @@ -1277,7 +1277,7 @@ def test_dataclass_support(client): class MyDataclass: val: int - @weave.op() + @weave.op def dataclass_maker(a: MyDataclass, b: MyDataclass) -> MyDataclass: return MyDataclass(a.val + b.val) @@ -1322,7 +1322,7 @@ def dataclass_maker(a: MyDataclass, b: MyDataclass) -> MyDataclass: def test_op_retrieval(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -1336,7 +1336,7 @@ def test_bound_op_retrieval(client): class CustomType(weave.Object): a: int - @weave.op() + @weave.op def op_with_custom_type(self, v): return self.a + v @@ -1359,7 +1359,7 @@ def test_bound_op_retrieval_no_self(client): class CustomTypeWithoutSelf(weave.Object): a: int - @weave.op() + @weave.op def op_with_custom_type(me, v): return me.a + v @@ -1387,7 +1387,7 @@ def test_dataset_row_ref(client): def test_tuple_support(client): - @weave.op() + @weave.op def tuple_maker(a, b): return (a, b) @@ -1411,7 +1411,7 @@ def tuple_maker(a, b): def test_namedtuple_support(client): - @weave.op() + @weave.op def tuple_maker(a, b): return (a, b) @@ -1442,7 +1442,7 @@ def test_named_reuse(client): d_ref = weave.publish(d, "test_dataset") dataset = weave.ref(d_ref.uri()).get() - @weave.op() + @weave.op async def dummy_score(output): return 1 @@ -1489,7 +1489,7 @@ class MyUnknownClassB: def __init__(self, b_val) -> None: self.b_val = b_val - @weave.op() + @weave.op def op_with_unknown_types(a: MyUnknownClassA, b: float) -> MyUnknownClassB: return MyUnknownClassB(a.a_val + b) @@ -1564,19 +1564,19 @@ def init_weave_get_server_patched(api_key): def test_single_primitive_output(client): - @weave.op() + @weave.op def single_int_output(a: int) -> int: return a - @weave.op() + @weave.op 
def single_bool_output(a: int) -> bool: return a == 1 - @weave.op() + @weave.op def single_none_output(a: int) -> None: return None - @weave.op() + @weave.op def dict_output(a: int, b: bool, c: None) -> dict: return {"a": a, "b": b, "c": c} @@ -1669,14 +1669,14 @@ def test_mapped_execution(client, mapper): events = [] - @weave.op() + @weave.op def op_a(a: int) -> int: events.append("A(S):" + str(a)) time.sleep(0.3) events.append("A(E):" + str(a)) return a - @weave.op() + @weave.op def op_b(b: int) -> int: events.append("B(S):" + str(b)) time.sleep(0.2) @@ -1684,7 +1684,7 @@ def op_b(b: int) -> int: events.append("B(E):" + str(b)) return res - @weave.op() + @weave.op def op_c(c: int) -> int: events.append("C(S):" + str(c)) time.sleep(0.1) @@ -1692,7 +1692,7 @@ def op_c(c: int) -> int: events.append("C(E):" + str(c)) return res - @weave.op() + @weave.op def op_mapper(vals): return mapper(op_c, vals) @@ -2127,7 +2127,7 @@ def calculate(a: int, b: int) -> int: def test_call_query_stream_columns(client): @weave.op - def calculate(a: int, b: int) -> int: + def calculate(a: int, b: int) -> dict[str, Any]: return {"result": {"a + b": a + b}, "not result": 123} for i in range(2): @@ -2170,7 +2170,7 @@ def test_call_query_stream_columns_with_costs(client): return @weave.op - def calculate(a: int, b: int) -> int: + def calculate(a: int, b: int) -> dict[str, Any]: return { "result": {"a + b": a + b}, "not result": 123, @@ -2238,7 +2238,7 @@ def calculate(a: int, b: int) -> int: @pytest.mark.skip("Not implemented: filter / sort through refs") def test_sort_and_filter_through_refs(client): - @weave.op() + @weave.op def test_op(label, val): return val @@ -2272,7 +2272,8 @@ def test_obj(val): # Ref at A, B and C test_op( - values[7], {"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})} + values[7], + {"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})}, ) for first, last, sort_by in [ @@ -2355,7 +2356,7 @@ def test_obj(val): def 
test_in_operation(client): - @weave.op() + @weave.op def test_op(label, val): return val @@ -2500,7 +2501,7 @@ def func(x): class BasicModel(weave.Model): - @weave.op() + @weave.op def predict(self, x): return {"answer": "42"} @@ -2546,7 +2547,7 @@ class SimpleObject(weave.Object): class NestedObject(weave.Object): b: SimpleObject - @weave.op() + @weave.op def return_nested_object(nested_obj: NestedObject): return nested_obj @@ -2997,3 +2998,224 @@ def foo(): foo() assert len(list(weave_client.get_calls())) == 1 assert weave.trace.weave_init._current_inited_client is None + + +def test_op_sampling(client): + never_traced_calls = 0 + always_traced_calls = 0 + sometimes_traced_calls = 0 + + @weave.op(tracing_sample_rate=0.0) + def never_traced(x: int) -> int: + nonlocal never_traced_calls + never_traced_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=1.0) + def always_traced(x: int) -> int: + nonlocal always_traced_calls + always_traced_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=0.5) + def sometimes_traced(x: int) -> int: + nonlocal sometimes_traced_calls + sometimes_traced_calls += 1 + return x + 1 + + weave.publish(never_traced) + # Never traced should execute but not be traced + for i in range(10): + never_traced(i) + assert never_traced_calls == 10 # Function was called + assert len(list(never_traced.calls())) == 0 # Not traced + + # Always traced should execute and be traced + for i in range(10): + always_traced(i) + assert always_traced_calls == 10 # Function was called + assert len(list(always_traced.calls())) == 10 # And traced + # Sanity check that the call_start was logged, unlike in the never_traced case. 
def test_op_sampling_async(client):
    """Check sampling of async root ops at sample rates 0.0, 1.0 and 0.5.

    Each op body must run on every call regardless of its sample rate; only
    the number of recorded traces is expected to depend on the rate.
    """
    import asyncio

    never_traced_calls = 0
    always_traced_calls = 0
    sometimes_traced_calls = 0

    @weave.op(tracing_sample_rate=0.0)
    async def never_traced(x: int) -> int:
        nonlocal never_traced_calls
        never_traced_calls += 1
        return x + 1

    @weave.op(tracing_sample_rate=1.0)
    async def always_traced(x: int) -> int:
        nonlocal always_traced_calls
        always_traced_calls += 1
        return x + 1

    @weave.op(tracing_sample_rate=0.5)
    async def sometimes_traced(x: int) -> int:
        nonlocal sometimes_traced_calls
        sometimes_traced_calls += 1
        return x + 1

    # NOTE(review): presumably published up front so that `.calls()` can be
    # queried on an op that never records a call -- confirm.
    weave.publish(never_traced)
    # Never traced should execute but not be traced
    for i in range(10):
        asyncio.run(never_traced(i))
    assert never_traced_calls == 10  # Function was called
    assert len(list(never_traced.calls())) == 0  # Not traced

    # Always traced should execute and be traced
    for i in range(10):
        asyncio.run(always_traced(i))
    assert always_traced_calls == 10  # Function was called
    assert len(list(always_traced.calls())) == 10  # And traced
    # Sanity check that the call_start was logged, unlike in the never_traced case.
    assert "call_start" in client.server.attribute_access_log

    # Sometimes traced should execute always but only be traced sometimes
    num_runs = 100
    for i in range(num_runs):
        asyncio.run(sometimes_traced(i))
    assert sometimes_traced_calls == num_runs  # Function was called every time
    num_traces = len(list(sometimes_traced.calls()))
    # 100 Bernoulli(0.5) trials have a standard deviation of 5. A (35, 65)
    # window is only +/-3 sigma and fails by chance roughly once every few
    # hundred runs; +/-5 sigma makes a spurious failure negligible while still
    # separating ~50% sampling from "never" (0) and "always" (100).
    assert 25 < num_traces < 75  # But only traced ~50% of the time
def test_op_sampling_inheritance_async(client):
    """Check that an async child op follows its parent's sampling decision.

    The parent has a sample rate of 0.0, so calls made through it must leave
    no trace for either op. Calling the child directly (as a root call, with
    the default sample rate) must trace every call.
    """
    import asyncio

    parent_calls = 0
    child_calls = 0

    @weave.op
    async def child_op(x: int) -> int:
        nonlocal child_calls
        child_calls += 1
        return x + 1

    @weave.op(tracing_sample_rate=0.0)
    async def parent_op(x: int) -> int:
        nonlocal parent_calls
        parent_calls += 1
        return await child_op(x)

    weave.publish(parent_op)
    # NOTE(review): published for the same reason as the parent -- presumably
    # so `.calls()` can be queried before any call has been traced; confirm.
    weave.publish(child_op)
    # When parent is sampled out, child should still execute but not be traced
    for i in range(10):
        asyncio.run(parent_op(i))

    assert parent_calls == 10  # Parent function executed
    assert child_calls == 10  # Child function executed
    assert len(list(parent_op.calls())) == 0  # Parent not traced
    # The comment above promises this, so verify it explicitly.
    assert len(list(child_op.calls())) == 0  # Child not traced either

    # Reset counters
    child_calls = 0

    # Direct calls to child should execute and be traced
    for i in range(10):
        asyncio.run(child_op(i))

    assert child_calls == 10  # Child function executed
    assert len(list(child_op.calls())) == 10  # And was traced
    assert "call_start" in client.server.attribute_access_log  # Verify tracing occurred
def test_op_sampling_child_follows_parent(client):
    """Check that a child's own sample rate is ignored under a traced parent.

    The child is declared with a rate of 0.0 and the parent with 1.0; the
    test expects one trace per call for both ops.
    """
    parent_calls = 0
    child_calls = 0

    @weave.op(tracing_sample_rate=0.0)  # Never traced
    def child_op(x: int) -> int:
        nonlocal child_calls
        child_calls += 1
        return x + 1

    @weave.op(tracing_sample_rate=1.0)  # Always traced
    def parent_op(x: int) -> int:
        nonlocal parent_calls
        parent_calls += 1
        return child_op(x)

    expected = 100
    for value in range(expected):
        parent_op(value)

    # Both bodies ran on every invocation.
    assert parent_calls == expected  # Parent was always executed
    assert child_calls == expected  # Child was always executed

    recorded_parent = list(parent_op.calls())
    recorded_child = list(child_op.calls())

    # Both ops were recorded on every invocation.
    assert len(recorded_parent) == expected  # Parent was always traced
    assert len(recorded_child) == expected  # Child was traced whenever parent was
evaluation = weave.Evaluation(dataset=ds, scorers=[my_second_scorer]) diff --git a/weave-js/package.json b/weave-js/package.json index 1f551021ed4..db96be60691 100644 --- a/weave-js/package.json +++ b/weave-js/package.json @@ -141,6 +141,7 @@ "unified": "^10.1.0", "unist-util-visit": "3.1.0", "universal-perf-hooks": "^1.0.1", + "uuid": "^11.0.3", "vega": "^5.24.0", "vega-lite": "5.6.0", "vega-tooltip": "^0.28.0", @@ -236,7 +237,6 @@ "tslint-config-prettier": "^1.18.0", "tslint-plugin-prettier": "^2.3.0", "typescript": "4.7.4", - "uuid": "^11.0.3", "vite": "5.2.9", "vitest": "^1.6.0" }, diff --git a/weave-js/src/assets/icons/icon-spiral.svg b/weave-js/src/assets/icons/icon-spiral.svg new file mode 100644 index 00000000000..ce5c147b43a --- /dev/null +++ b/weave-js/src/assets/icons/icon-spiral.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/weave-js/src/components/Form/TextField.tsx b/weave-js/src/components/Form/TextField.tsx index c40f697ac85..8f5dd1171ff 100644 --- a/weave-js/src/components/Form/TextField.tsx +++ b/weave-js/src/components/Form/TextField.tsx @@ -37,6 +37,7 @@ type TextFieldProps = { dataTest?: string; step?: number; variant?: 'default' | 'ghost'; + isContainerNightAware?: boolean; }; export const TextField = ({ @@ -59,6 +60,7 @@ export const TextField = ({ autoComplete, dataTest, step, + isContainerNightAware, }: TextFieldProps) => { const textFieldSize = size ?? 'medium'; const leftPaddingForIcon = textFieldSize === 'medium' ? 'pl-34' : 'pl-36'; @@ -83,7 +85,6 @@ export const TextField = ({
( export const IconSortDescending = (props: SVGIconProps) => ( ); +export const IconSpiral = (props: SVGIconProps) => ( + +); export const IconSplit = (props: SVGIconProps) => ( ); @@ -1295,6 +1299,7 @@ const ICON_NAME_TO_ICON: Record = { sort: IconSort, 'sort-ascending': IconSortAscending, 'sort-descending': IconSortDescending, + spiral: IconSpiral, split: IconSplit, square: IconSquare, star: IconStar, diff --git a/weave-js/src/components/Icon/index.ts b/weave-js/src/components/Icon/index.ts index 08bf7854ad2..39c6eed3170 100644 --- a/weave-js/src/components/Icon/index.ts +++ b/weave-js/src/components/Icon/index.ts @@ -211,6 +211,7 @@ export { IconSort, IconSortAscending, IconSortDescending, + IconSpiral, IconSplit, IconSquare, IconStar, diff --git a/weave-js/src/components/Icon/types.ts b/weave-js/src/components/Icon/types.ts index 7ca30049257..47f5f357adc 100644 --- a/weave-js/src/components/Icon/types.ts +++ b/weave-js/src/components/Icon/types.ts @@ -210,6 +210,7 @@ export const IconNames = { Sort: 'sort', SortAscending: 'sort-ascending', SortDescending: 'sort-descending', + Spiral: 'spiral', Split: 'split', Square: 'square', Star: 'star', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CustomGridTreeDataGroupingCell.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CustomGridTreeDataGroupingCell.tsx index 4db1b6334ad..05a71c7d4ab 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CustomGridTreeDataGroupingCell.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CustomGridTreeDataGroupingCell.tsx @@ -173,7 +173,7 @@ export const CustomGridTreeDataGroupingCell: FC< diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx index 138ca10c7e8..135d297539d 100644 --- 
a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx @@ -31,7 +31,11 @@ export const ChoicesView = ({ } if (choices.length === 1) { return ( - + ); } return ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx index f570b2f6295..8cec95707fa 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx @@ -47,6 +47,13 @@ export const MessagePanel = ({ } }, [message.content, contentRef?.current?.scrollHeight]); + // Set isShowingMore to true when editor is opened + useEffect(() => { + if (editorHeight !== null) { + setIsShowingMore(true); + } + }, [editorHeight]); + const isUser = message.role === 'user'; const isSystemPrompt = message.role === 'system'; const isTool = message.role === 'tool'; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelEditor.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelEditor.tsx index aa519c9659b..d643f103481 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelEditor.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelEditor.tsx @@ -68,13 +68,12 @@ export const PlaygroundMessagePanelEditor: React.FC<
setEditedContent(e.target.value)} - autoGrow - maxHeight={160} + startHeight={320} /> {/* 6px vs. 8px to make up for extra padding from textarea field */}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx index 38f69c4482e..b6b6e7c420d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx @@ -35,7 +35,6 @@ export const PlaygroundChat = ({ setSettingsTab, settingsTab, }: PlaygroundChatProps) => { - console.log('playgroundStates', playgroundStates); const [chatText, setChatText] = useState(''); const [isLoading, setIsLoading] = useState(false); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx index 0ce3ad02b51..804670a1dc3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx @@ -43,8 +43,6 @@ export const useChatFunctions = ( messageIndex: number, newMessage: Message ) => { - console.log('editMessage', callIndex, messageIndex, newMessage); - setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { const newTraceCall = clearTraceCall( cloneDeep(prevTraceCall as OptionalTraceCallSchema) @@ -108,7 +106,6 @@ export const useChatFunctions = ( choiceIndex: number, newChoice: Message ) => { - console.log('editChoice', callIndex, choiceIndex, newChoice); setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { const newTraceCall = clearTraceCall( cloneDeep(prevTraceCall as OptionalTraceCallSchema) diff --git 
a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx index c6232631e4e..76d1c6d9e31 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx @@ -7,7 +7,11 @@ import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; import {useWFHooks} from '../wfReactInterface/context'; import {PlaygroundChat} from './PlaygroundChat/PlaygroundChat'; import {PlaygroundSettings} from './PlaygroundSettings/PlaygroundSettings'; -import {DEFAULT_SYSTEM_MESSAGE, usePlaygroundState} from './usePlaygroundState'; +import { + DEFAULT_SYSTEM_MESSAGE, + parseTraceCall, + usePlaygroundState, +} from './usePlaygroundState'; export type PlaygroundPageProps = { entity: string; @@ -89,7 +93,10 @@ export const PlaygroundPageInner = (props: PlaygroundPageProps) => { for (const [idx, state] of newStates.entries()) { for (const c of calls || []) { if (state.traceCall.id === c.callId) { - newStates[idx] = {...state, traceCall: c.traceCall || {}}; + newStates[idx] = { + ...state, + traceCall: parseTraceCall(c.traceCall || {}), + }; break; } } diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/StyledTextarea.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/StyledTextarea.tsx index 14a5b121da5..97d6aa7d39b 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/StyledTextarea.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/StyledTextarea.tsx @@ -8,11 +8,12 @@ import React, {forwardRef} from 'react'; type TextAreaProps = React.TextareaHTMLAttributes & { autoGrow?: boolean; maxHeight?: string | number; + startHeight?: string | number; reset?: boolean; 
}; export const StyledTextArea = forwardRef( - ({className, autoGrow, maxHeight, reset, ...props}, ref) => { + ({className, autoGrow, maxHeight, startHeight, reset, ...props}, ref) => { const textareaRef = React.useRef(null); React.useEffect(() => { @@ -26,11 +27,22 @@ export const StyledTextArea = forwardRef( return; } - // Disable resize when autoGrow is true - textareaElement.style.resize = 'none'; + // Only disable resize when autoGrow is true + textareaElement.style.resize = autoGrow ? 'none' : 'vertical'; + + // Set initial height if provided + if (startHeight && textareaElement.value === '') { + textareaElement.style.height = + typeof startHeight === 'number' ? `${startHeight}px` : startHeight; + return; + } if (reset || textareaElement.value === '') { - textareaElement.style.height = 'auto'; + textareaElement.style.height = startHeight + ? typeof startHeight === 'number' + ? `${startHeight}px` + : startHeight + : 'auto'; return; } @@ -63,7 +75,7 @@ export const StyledTextArea = forwardRef( return () => textareaRefElement.removeEventListener('input', adjustHeight); - }, [autoGrow, maxHeight, reset]); + }, [autoGrow, maxHeight, reset, startHeight]); return ( @@ -86,6 +98,7 @@ export const StyledTextArea = forwardRef( 'focus:outline-none', 'relative bottom-0 top-0 items-center rounded-sm', 'outline outline-1 outline-moon-250', + !autoGrow && 'resize-y', props.disabled ? 'opacity-50' : 'hover:outline hover:outline-2 hover:outline-teal-500/40 focus:outline-2', @@ -94,6 +107,14 @@ export const StyledTextArea = forwardRef( 'placeholder-moon-500 dark:placeholder-moon-600', className )} + style={{ + height: startHeight + ? typeof startHeight === 'number' + ? 
`${startHeight}px` + : startHeight + : undefined, + ...props.style, + }} {...props} /> diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts index 8439abc4ddf..cbcc7c52fb8 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts @@ -1,5 +1,11 @@ +import {cloneDeep} from 'lodash'; import {SetStateAction, useCallback, useState} from 'react'; +import { + anthropicContentBlocksToChoices, + hasStringProp, + isAnthropicCompletionFormat, +} from '../ChatView/hooks'; import {LLM_MAX_TOKENS_KEYS, LLMMaxTokensKey} from './llmMaxTokens'; import { OptionalTraceCallSchema, @@ -77,7 +83,7 @@ export const usePlaygroundState = () => { setPlaygroundStates(prevState => { const newState = {...prevState[0]}; - newState.traceCall = traceCall; + newState.traceCall = parseTraceCall(traceCall); if (!inputs) { return [newState]; @@ -155,3 +161,35 @@ export const getInputFromPlaygroundState = (state: PlaygroundState) => { tools: tools.length > 0 ? 
// Helper that normalizes an Anthropic-style trace call so that the playground
// can display it: the output's top-level `content` / `stop_reason` become
// `choices`, and a top-level `system` input becomes a leading system message.
// Works on a deep copy; the `traceCall` argument is not mutated.
export const parseTraceCall = (traceCall: OptionalTraceCallSchema) => {
  const parsedTraceCall = cloneDeep(traceCall);

  // Handles anthropic outputs
  // Anthropic has content and stop_reason as top-level fields
  if (isAnthropicCompletionFormat(parsedTraceCall.output)) {
    const {content, stop_reason, ...outputs} = parsedTraceCall.output as any;
    parsedTraceCall.output = {
      ...outputs,
      choices: anthropicContentBlocksToChoices(content, stop_reason),
    };
  }
  // Handles anthropic inputs
  // Anthropic has system message as a top-level request field
  if (hasStringProp(parsedTraceCall.inputs, 'system')) {
    const {messages, system, ...inputs} = parsedTraceCall.inputs as any;
    // `messages` may be missing (or not a list) on a partially filled call;
    // spreading a non-iterable into an array literal throws a TypeError.
    const existingMessages = Array.isArray(messages) ? messages : [];
    parsedTraceCall.inputs = {
      ...inputs,
      messages: [
        {
          role: 'system',
          content: system,
        },
        ...existingMessages,
      ],
    };
  }
  return parsedTraceCall;
};
def get_tracing_enabled() -> bool:
    """Return the current value of the `_tracing_enabled` context variable."""
    enabled = _tracing_enabled.get()
    return enabled


@contextlib.contextmanager
def set_tracing_enabled(enabled: bool) -> Iterator[None]:
    """Set the tracing flag to `enabled` for the duration of a `with` block.

    The value that was in effect before entry is restored on exit, including
    when the body raises.
    """
    reset_token = _tracing_enabled.set(enabled)
    try:
        yield
    finally:
        _tracing_enabled.reset(reset_token)


@contextlib.contextmanager
def tracing_disabled() -> Iterator[None]:
    """Set the tracing flag to False for the duration of a `with` block.

    The value that was in effect before entry is restored on exit, including
    when the body raises.
    """
    reset_token = _tracing_enabled.set(False)
    try:
        yield
    finally:
        _tracing_enabled.reset(reset_token)
_tracing_enabled: bool + tracing_sample_rate: float + def _set_on_input_handler(func: Op, on_input: OnInputHandlerType) -> None: if func._on_input_handler is not None: @@ -407,37 +414,54 @@ def _do_call( if not pargs: pargs = _default_on_input_handler(op, args, kwargs) + # Handle all of the possible cases where we would skip tracing. if settings.should_disable_weave(): res = func(*pargs.args, **pargs.kwargs) - elif weave_client_context.get_weave_client() is None: + return res, call + if weave_client_context.get_weave_client() is None: + res = func(*pargs.args, **pargs.kwargs) + return res, call + if not op._tracing_enabled: + res = func(*pargs.args, **pargs.kwargs) + return res, call + if not get_tracing_enabled(): res = func(*pargs.args, **pargs.kwargs) - elif not op._tracing_enabled: + return res, call + + current_call = call_context.get_current_call() + if current_call is None: + # Root call: decide whether to trace based on sample rate + if random.random() > op.tracing_sample_rate: + # Disable tracing for this call and all descendants + with tracing_disabled(): + res = func(*pargs.args, **pargs.kwargs) + return res, call + + # Proceed with tracing. Note that we don't check the sample rate here. + # Only root calls get sampling applied. + # If the parent was traced (sampled in), the child will be too. 
+ try: + call = _create_call(op, *args, __weave=__weave, **kwargs) + except OpCallError as e: + raise e + except Exception as e: + if get_raise_on_captured_errors(): + raise + log_once( + logger.error, + CALL_CREATE_MSG.format(traceback.format_exc()), + ) res = func(*pargs.args, **pargs.kwargs) else: - try: - # This try/except allows us to fail gracefully and - # still let the user code continue to execute - call = _create_call(op, *args, __weave=__weave, **kwargs) - except OpCallError as e: - raise e - except Exception as e: - if get_raise_on_captured_errors(): - raise - log_once( - logger.error, - CALL_CREATE_MSG.format(traceback.format_exc()), - ) - res = func(*pargs.args, **pargs.kwargs) - else: - execute_result = _execute_op( - op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs + execute_result = _execute_op( + op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs + ) + if inspect.iscoroutine(execute_result): + raise TypeError( + "Internal error: Expected `_execute_call` to return a sync result" ) - if inspect.iscoroutine(execute_result): - raise TypeError( - "Internal error: Expected `_execute_call` to return a sync result" - ) - execute_result = cast(tuple[Any, "Call"], execute_result) - res, call = execute_result + execute_result = cast(tuple[Any, "Call"], execute_result) + res, call = execute_result return res, call @@ -450,39 +474,52 @@ async def _do_call_async( ) -> tuple[Any, Call]: func = op.resolve_fn call = _placeholder_call() + + # Handle all of the possible cases where we would skip tracing. 
if settings.should_disable_weave(): res = await func(*args, **kwargs) - elif weave_client_context.get_weave_client() is None: + return res, call + if weave_client_context.get_weave_client() is None: res = await func(*args, **kwargs) - elif not op._tracing_enabled: + return res, call + if not op._tracing_enabled: + res = await func(*args, **kwargs) + return res, call + if not get_tracing_enabled(): + res = await func(*args, **kwargs) + return res, call + + current_call = call_context.get_current_call() + if current_call is None: + # Root call: decide whether to trace based on sample rate + if random.random() > op.tracing_sample_rate: + # Disable tracing for this call and all descendants + with tracing_disabled(): + res = await func(*args, **kwargs) + return res, call + + # Proceed with tracing + try: + call = _create_call(op, *args, __weave=__weave, **kwargs) + except OpCallError as e: + raise e + except Exception as e: + if get_raise_on_captured_errors(): + raise + log_once( + logger.error, + ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), + ) res = await func(*args, **kwargs) else: - try: - # This try/except allows us to fail gracefully and - # still let the user code continue to execute - call = _create_call(op, *args, __weave=__weave, **kwargs) - except OpCallError as e: - raise e - except Exception as e: - if get_raise_on_captured_errors(): - raise - log_once( - logger.error, - ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), - ) - res = await func(*args, **kwargs) - else: - execute_result = _execute_op( - op, call, *args, __should_raise=__should_raise, **kwargs - ) - if not inspect.iscoroutine(execute_result): - raise TypeError( - "Internal error: Expected `_execute_call` to return a coroutine" - ) - execute_result = cast( - Coroutine[Any, Any, tuple[Any, "Call"]], execute_result + execute_result = _execute_op( + op, call, *args, __should_raise=__should_raise, **kwargs + ) + if not inspect.iscoroutine(execute_result): + raise TypeError( + "Internal 
error: Expected `_execute_call` to return a coroutine" ) - res, call = await execute_result + res, call = await execute_result return res, call @@ -540,6 +577,7 @@ def op( call_display_name: str | CallDisplayNameFunc | None = None, postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, + tracing_sample_rate: float = 1.0, ) -> Callable[[Callable], Op] | Op: """ A decorator to weave op-ify a function or method. Works for both sync and async. @@ -565,6 +603,7 @@ def op( postprocess_output (Optional[Callable[..., Any]]): A function to process the output after it's been returned from the function but before it's logged. This does not affect the actual output of the function, only the displayed output. + tracing_sample_rate (float): The sampling rate for tracing this function. Defaults to 1.0 (always trace). Returns: Union[Callable[[Any], Op], Op]: If called without arguments, returns a decorator. @@ -591,6 +630,10 @@ async def extract(): await extract() # calls the function and tracks the call in the Weave UI ``` """ + if not isinstance(tracing_sample_rate, (int, float)): + raise TypeError("tracing_sample_rate must be a float") + if not 0 <= tracing_sample_rate <= 1: + raise ValueError("tracing_sample_rate must be between 0 and 1") def op_deco(func: Callable) -> Op: # Check function type @@ -647,6 +690,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: wrapper._on_finish_handler = None # type: ignore wrapper._tracing_enabled = True # type: ignore + wrapper.tracing_sample_rate = tracing_sample_rate # type: ignore wrapper.get_captured_code = partial(get_captured_code, wrapper) # type: ignore