setEditedContent(e.target.value)}
- autoGrow
- maxHeight={160}
+ startHeight={320}
/>
{/* 6px vs. 8px to make up for extra padding from textarea field */}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx
index 38f69c4482e..b6b6e7c420d 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx
@@ -35,7 +35,6 @@ export const PlaygroundChat = ({
setSettingsTab,
settingsTab,
}: PlaygroundChatProps) => {
- console.log('playgroundStates', playgroundStates);
const [chatText, setChatText] = useState('');
const [isLoading, setIsLoading] = useState(false);
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx
index 0ce3ad02b51..804670a1dc3 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx
@@ -43,8 +43,6 @@ export const useChatFunctions = (
messageIndex: number,
newMessage: Message
) => {
- console.log('editMessage', callIndex, messageIndex, newMessage);
-
setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => {
const newTraceCall = clearTraceCall(
cloneDeep(prevTraceCall as OptionalTraceCallSchema)
@@ -108,7 +106,6 @@ export const useChatFunctions = (
choiceIndex: number,
newChoice: Message
) => {
- console.log('editChoice', callIndex, choiceIndex, newChoice);
setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => {
const newTraceCall = clearTraceCall(
cloneDeep(prevTraceCall as OptionalTraceCallSchema)
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx
index c6232631e4e..76d1c6d9e31 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx
@@ -7,7 +7,11 @@ import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout';
import {useWFHooks} from '../wfReactInterface/context';
import {PlaygroundChat} from './PlaygroundChat/PlaygroundChat';
import {PlaygroundSettings} from './PlaygroundSettings/PlaygroundSettings';
-import {DEFAULT_SYSTEM_MESSAGE, usePlaygroundState} from './usePlaygroundState';
+import {
+ DEFAULT_SYSTEM_MESSAGE,
+ parseTraceCall,
+ usePlaygroundState,
+} from './usePlaygroundState';
export type PlaygroundPageProps = {
entity: string;
@@ -89,7 +93,10 @@ export const PlaygroundPageInner = (props: PlaygroundPageProps) => {
for (const [idx, state] of newStates.entries()) {
for (const c of calls || []) {
if (state.traceCall.id === c.callId) {
- newStates[idx] = {...state, traceCall: c.traceCall || {}};
+ newStates[idx] = {
+ ...state,
+ traceCall: parseTraceCall(c.traceCall || {}),
+ };
break;
}
}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/StyledTextarea.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/StyledTextarea.tsx
index 14a5b121da5..97d6aa7d39b 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/StyledTextarea.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/StyledTextarea.tsx
@@ -8,11 +8,12 @@ import React, {forwardRef} from 'react';
type TextAreaProps = React.TextareaHTMLAttributes & {
autoGrow?: boolean;
maxHeight?: string | number;
+ startHeight?: string | number;
reset?: boolean;
};
export const StyledTextArea = forwardRef(
- ({className, autoGrow, maxHeight, reset, ...props}, ref) => {
+ ({className, autoGrow, maxHeight, startHeight, reset, ...props}, ref) => {
const textareaRef = React.useRef(null);
React.useEffect(() => {
@@ -26,11 +27,22 @@ export const StyledTextArea = forwardRef(
return;
}
- // Disable resize when autoGrow is true
- textareaElement.style.resize = 'none';
+ // Only disable resize when autoGrow is true
+ textareaElement.style.resize = autoGrow ? 'none' : 'vertical';
+
+ // Set initial height if provided
+ if (startHeight && textareaElement.value === '') {
+ textareaElement.style.height =
+ typeof startHeight === 'number' ? `${startHeight}px` : startHeight;
+ return;
+ }
if (reset || textareaElement.value === '') {
- textareaElement.style.height = 'auto';
+ textareaElement.style.height = startHeight
+ ? typeof startHeight === 'number'
+ ? `${startHeight}px`
+ : startHeight
+ : 'auto';
return;
}
@@ -63,7 +75,7 @@ export const StyledTextArea = forwardRef(
return () =>
textareaRefElement.removeEventListener('input', adjustHeight);
- }, [autoGrow, maxHeight, reset]);
+ }, [autoGrow, maxHeight, reset, startHeight]);
return (
@@ -86,6 +98,7 @@ export const StyledTextArea = forwardRef(
'focus:outline-none',
'relative bottom-0 top-0 items-center rounded-sm',
'outline outline-1 outline-moon-250',
+ !autoGrow && 'resize-y',
props.disabled
? 'opacity-50'
: 'hover:outline hover:outline-2 hover:outline-teal-500/40 focus:outline-2',
@@ -94,6 +107,14 @@ export const StyledTextArea = forwardRef(
'placeholder-moon-500 dark:placeholder-moon-600',
className
)}
+ style={{
+ height: startHeight
+ ? typeof startHeight === 'number'
+ ? `${startHeight}px`
+ : startHeight
+ : undefined,
+ ...props.style,
+ }}
{...props}
/>
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts
index 8439abc4ddf..cbcc7c52fb8 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts
@@ -1,5 +1,11 @@
+import {cloneDeep} from 'lodash';
import {SetStateAction, useCallback, useState} from 'react';
+import {
+ anthropicContentBlocksToChoices,
+ hasStringProp,
+ isAnthropicCompletionFormat,
+} from '../ChatView/hooks';
import {LLM_MAX_TOKENS_KEYS, LLMMaxTokensKey} from './llmMaxTokens';
import {
OptionalTraceCallSchema,
@@ -77,7 +83,7 @@ export const usePlaygroundState = () => {
setPlaygroundStates(prevState => {
const newState = {...prevState[0]};
- newState.traceCall = traceCall;
+ newState.traceCall = parseTraceCall(traceCall);
if (!inputs) {
return [newState];
@@ -155,3 +161,35 @@ export const getInputFromPlaygroundState = (state: PlaygroundState) => {
tools: tools.length > 0 ? tools : undefined,
};
};
+
+// This is a helper function to parse the trace call output for anthropic
+// so that the playground can display the choices
+export const parseTraceCall = (traceCall: OptionalTraceCallSchema) => {
+ const parsedTraceCall = cloneDeep(traceCall);
+
+ // Handles anthropic outputs
+ // Anthropic has content and stop_reason as top-level fields
+ if (isAnthropicCompletionFormat(parsedTraceCall.output)) {
+ const {content, stop_reason, ...outputs} = parsedTraceCall.output as any;
+ parsedTraceCall.output = {
+ ...outputs,
+ choices: anthropicContentBlocksToChoices(content, stop_reason),
+ };
+ }
+ // Handles anthropic inputs
+ // Anthropic has system message as a top-level request field
+ if (hasStringProp(parsedTraceCall.inputs, 'system')) {
+ const {messages, system, ...inputs} = parsedTraceCall.inputs as any;
+ parsedTraceCall.inputs = {
+ ...inputs,
+ messages: [
+ {
+ role: 'system',
+ content: system,
+ },
+ ...messages,
+ ],
+ };
+ }
+ return parsedTraceCall;
+};
diff --git a/weave/scorers/base_scorer.py b/weave/scorers/base_scorer.py
index 5a19adcd04f..4ac27f1a76b 100644
--- a/weave/scorers/base_scorer.py
+++ b/weave/scorers/base_scorer.py
@@ -1,4 +1,5 @@
import inspect
+import textwrap
from collections.abc import Sequence
from numbers import Number
from typing import Any, Callable, Optional, Union
@@ -45,7 +46,13 @@ def _validate_scorer_signature(scorer: Union[Callable, Op, Scorer]) -> bool:
params = inspect.signature(scorer).parameters
if "output" in params and "model_output" in params:
raise ValueError(
- "Both 'output' and 'model_output' cannot be in the scorer signature; prefer just using `output`."
+ textwrap.dedent(
+ """
+ The scorer signature cannot include both `output` and `model_output` at the same time.
+
+ To resolve, rename one of the arguments to avoid conflict. Prefer using `output` as the model's output.
+ """
+ )
)
return True
diff --git a/weave/trace/context/call_context.py b/weave/trace/context/call_context.py
index 402e1843ade..3a03bd167c3 100644
--- a/weave/trace/context/call_context.py
+++ b/weave/trace/context/call_context.py
@@ -20,6 +20,8 @@ class NoCurrentCallError(Exception): ...
logger = logging.getLogger(__name__)
+_tracing_enabled = contextvars.ContextVar("tracing_enabled", default=True)
+
def push_call(call: Call) -> None:
new_stack = copy.copy(_call_stack.get())
@@ -136,3 +138,22 @@ def set_call_stack(stack: list[Call]) -> Iterator[list[Call]]:
call_attributes: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar(
"call_attributes", default={}
)
+
+
+def get_tracing_enabled() -> bool:
+ return _tracing_enabled.get()
+
+
+@contextlib.contextmanager
+def set_tracing_enabled(enabled: bool) -> Iterator[None]:
+ token = _tracing_enabled.set(enabled)
+ try:
+ yield
+ finally:
+ _tracing_enabled.reset(token)
+
+
+@contextlib.contextmanager
+def tracing_disabled() -> Iterator[None]:
+ with set_tracing_enabled(False):
+ yield
diff --git a/weave/trace/op.py b/weave/trace/op.py
index 2b5835474d8..a89c7400d8b 100644
--- a/weave/trace/op.py
+++ b/weave/trace/op.py
@@ -4,6 +4,7 @@
import inspect
import logging
+import random
import sys
import traceback
from collections.abc import Coroutine, Mapping
@@ -26,7 +27,11 @@
from weave.trace.constants import TRACE_CALL_EMOJI
from weave.trace.context import call_context
from weave.trace.context import weave_client_context as weave_client_context
-from weave.trace.context.call_context import call_attributes
+from weave.trace.context.call_context import (
+ call_attributes,
+ get_tracing_enabled,
+ tracing_disabled,
+)
from weave.trace.context.tests_context import get_raise_on_captured_errors
from weave.trace.errors import OpCallError
from weave.trace.refs import ObjectRef
@@ -174,6 +179,8 @@ class Op(Protocol):
# it disables child ops as well.
_tracing_enabled: bool
+ tracing_sample_rate: float
+
def _set_on_input_handler(func: Op, on_input: OnInputHandlerType) -> None:
if func._on_input_handler is not None:
@@ -407,37 +414,54 @@ def _do_call(
if not pargs:
pargs = _default_on_input_handler(op, args, kwargs)
+ # Handle all of the possible cases where we would skip tracing.
if settings.should_disable_weave():
res = func(*pargs.args, **pargs.kwargs)
- elif weave_client_context.get_weave_client() is None:
+ return res, call
+ if weave_client_context.get_weave_client() is None:
+ res = func(*pargs.args, **pargs.kwargs)
+ return res, call
+ if not op._tracing_enabled:
+ res = func(*pargs.args, **pargs.kwargs)
+ return res, call
+ if not get_tracing_enabled():
res = func(*pargs.args, **pargs.kwargs)
- elif not op._tracing_enabled:
+ return res, call
+
+ current_call = call_context.get_current_call()
+ if current_call is None:
+ # Root call: decide whether to trace based on sample rate
+ if random.random() > op.tracing_sample_rate:
+ # Disable tracing for this call and all descendants
+ with tracing_disabled():
+ res = func(*pargs.args, **pargs.kwargs)
+ return res, call
+
+ # Proceed with tracing. Note that we don't check the sample rate here.
+ # Only root calls get sampling applied.
+ # If the parent was traced (sampled in), the child will be too.
+ try:
+ call = _create_call(op, *args, __weave=__weave, **kwargs)
+ except OpCallError as e:
+ raise e
+ except Exception as e:
+ if get_raise_on_captured_errors():
+ raise
+ log_once(
+ logger.error,
+ CALL_CREATE_MSG.format(traceback.format_exc()),
+ )
res = func(*pargs.args, **pargs.kwargs)
else:
- try:
- # This try/except allows us to fail gracefully and
- # still let the user code continue to execute
- call = _create_call(op, *args, __weave=__weave, **kwargs)
- except OpCallError as e:
- raise e
- except Exception as e:
- if get_raise_on_captured_errors():
- raise
- log_once(
- logger.error,
- CALL_CREATE_MSG.format(traceback.format_exc()),
- )
- res = func(*pargs.args, **pargs.kwargs)
- else:
- execute_result = _execute_op(
- op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs
+ execute_result = _execute_op(
+ op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs
+ )
+ if inspect.iscoroutine(execute_result):
+ raise TypeError(
+ "Internal error: Expected `_execute_call` to return a sync result"
)
- if inspect.iscoroutine(execute_result):
- raise TypeError(
- "Internal error: Expected `_execute_call` to return a sync result"
- )
- execute_result = cast(tuple[Any, "Call"], execute_result)
- res, call = execute_result
+ execute_result = cast(tuple[Any, "Call"], execute_result)
+ res, call = execute_result
return res, call
@@ -450,39 +474,52 @@ async def _do_call_async(
) -> tuple[Any, Call]:
func = op.resolve_fn
call = _placeholder_call()
+
+ # Handle all of the possible cases where we would skip tracing.
if settings.should_disable_weave():
res = await func(*args, **kwargs)
- elif weave_client_context.get_weave_client() is None:
+ return res, call
+ if weave_client_context.get_weave_client() is None:
res = await func(*args, **kwargs)
- elif not op._tracing_enabled:
+ return res, call
+ if not op._tracing_enabled:
+ res = await func(*args, **kwargs)
+ return res, call
+ if not get_tracing_enabled():
+ res = await func(*args, **kwargs)
+ return res, call
+
+ current_call = call_context.get_current_call()
+ if current_call is None:
+ # Root call: decide whether to trace based on sample rate
+ if random.random() > op.tracing_sample_rate:
+ # Disable tracing for this call and all descendants
+ with tracing_disabled():
+ res = await func(*args, **kwargs)
+ return res, call
+
+ # Proceed with tracing
+ try:
+ call = _create_call(op, *args, __weave=__weave, **kwargs)
+ except OpCallError as e:
+ raise e
+ except Exception as e:
+ if get_raise_on_captured_errors():
+ raise
+ log_once(
+ logger.error,
+ ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()),
+ )
res = await func(*args, **kwargs)
else:
- try:
- # This try/except allows us to fail gracefully and
- # still let the user code continue to execute
- call = _create_call(op, *args, __weave=__weave, **kwargs)
- except OpCallError as e:
- raise e
- except Exception as e:
- if get_raise_on_captured_errors():
- raise
- log_once(
- logger.error,
- ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()),
- )
- res = await func(*args, **kwargs)
- else:
- execute_result = _execute_op(
- op, call, *args, __should_raise=__should_raise, **kwargs
- )
- if not inspect.iscoroutine(execute_result):
- raise TypeError(
- "Internal error: Expected `_execute_call` to return a coroutine"
- )
- execute_result = cast(
- Coroutine[Any, Any, tuple[Any, "Call"]], execute_result
+ execute_result = _execute_op(
+ op, call, *args, __should_raise=__should_raise, **kwargs
+ )
+ if not inspect.iscoroutine(execute_result):
+ raise TypeError(
+ "Internal error: Expected `_execute_call` to return a coroutine"
)
- res, call = await execute_result
+ res, call = await execute_result
return res, call
@@ -540,6 +577,7 @@ def op(
call_display_name: str | CallDisplayNameFunc | None = None,
postprocess_inputs: PostprocessInputsFunc | None = None,
postprocess_output: PostprocessOutputFunc | None = None,
+ tracing_sample_rate: float = 1.0,
) -> Callable[[Callable], Op] | Op:
"""
A decorator to weave op-ify a function or method. Works for both sync and async.
@@ -565,6 +603,7 @@ def op(
postprocess_output (Optional[Callable[..., Any]]): A function to process the output
after it's been returned from the function but before it's logged. This does not
affect the actual output of the function, only the displayed output.
+ tracing_sample_rate (float): The sampling rate for tracing this function. Defaults to 1.0 (always trace).
Returns:
Union[Callable[[Any], Op], Op]: If called without arguments, returns a decorator.
@@ -591,6 +630,10 @@ async def extract():
await extract() # calls the function and tracks the call in the Weave UI
```
"""
+ if not isinstance(tracing_sample_rate, (int, float)):
+ raise TypeError("tracing_sample_rate must be a float")
+ if not 0 <= tracing_sample_rate <= 1:
+ raise ValueError("tracing_sample_rate must be between 0 and 1")
def op_deco(func: Callable) -> Op:
# Check function type
@@ -647,6 +690,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
wrapper._on_finish_handler = None # type: ignore
wrapper._tracing_enabled = True # type: ignore
+ wrapper.tracing_sample_rate = tracing_sample_rate # type: ignore
wrapper.get_captured_code = partial(get_captured_code, wrapper) # type: ignore