Commit 2684774

Merge branch 'master' into DOCS-1128
J2-D2-3PO authored Dec 15, 2024
2 parents daf098d + 1571732 commit 2684774
Showing 7 changed files with 72 additions and 38 deletions.
9 changes: 7 additions & 2 deletions tests/trace/test_client_trace.py
@@ -2,6 +2,7 @@
import datetime
import json
import platform
import random
import sys
import time
from collections import defaultdict, namedtuple
@@ -3005,6 +3006,8 @@ def test_op_sampling(client):
always_traced_calls = 0
sometimes_traced_calls = 0

random.seed(0)

@weave.op(tracing_sample_rate=0.0)
def never_traced(x: int) -> int:
nonlocal never_traced_calls
@@ -3044,14 +3047,16 @@ def sometimes_traced(x: int) -> int:
sometimes_traced(i)
assert sometimes_traced_calls == num_runs # Function was called every time
num_traces = len(list(sometimes_traced.calls()))
assert 35 < num_traces < 65 # But only traced ~50% of the time
assert num_traces == 38


def test_op_sampling_async(client):
never_traced_calls = 0
always_traced_calls = 0
sometimes_traced_calls = 0

random.seed(0)

@weave.op(tracing_sample_rate=0.0)
async def never_traced(x: int) -> int:
nonlocal never_traced_calls
@@ -3092,7 +3097,7 @@ async def sometimes_traced(x: int) -> int:
asyncio.run(sometimes_traced(i))
assert sometimes_traced_calls == num_runs # Function was called every time
num_traces = len(list(sometimes_traced.calls()))
assert 35 < num_traces < 65 # But only traced ~50% of the time
assert num_traces == 38
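Seeding the RNG is what lets the test pin an exact count: tracing_sample_rate appears to draw from Python's global random module, so a fixed seed makes the 50% sampler reproducible. A minimal standalone sketch of the same idea, assuming a project initialized via weave.init (the project name is hypothetical, and the exact count depends on how the sampler consumes the RNG):

import random

import weave

weave.init("sampling-demo")  # hypothetical project name

random.seed(0)  # pin the global RNG so the 50% sampler is reproducible

@weave.op(tracing_sample_rate=0.5)
def sometimes_traced(x: int) -> int:
    return x + 1

for i in range(100):
    sometimes_traced(i)  # every call executes; only a sampled subset is traced

num_traces = len(list(sometimes_traced.calls()))
print(num_traces)  # fixed for a given seed (the test above pins it to 38)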


def test_op_sampling_inheritance(client):
@@ -1,31 +1,19 @@
import {Button} from '@wandb/weave/components/Button';
import {Pill} from '@wandb/weave/components/Tag';
import {Tailwind} from '@wandb/weave/components/Tailwind';
import {Tooltip} from '@wandb/weave/components/Tooltip';
import {makeRefCall} from '@wandb/weave/util/refs';
import React from 'react';
import {useHistory} from 'react-router-dom';

import {useWeaveflowRouteContext} from '../../../context';
import {Reactions} from '../../../feedback/Reactions';
import {TraceCostStats} from '../../CallPage/cost';
import {TraceCallSchema} from '../../wfReactInterface/traceServerClientTypes';

export const PlaygroundCallStats = ({call}: {call: TraceCallSchema}) => {
let totalTokens = 0;
if (call?.summary?.usage) {
for (const key of Object.keys(call.summary.usage)) {
totalTokens +=
call.summary.usage[key].prompt_tokens ||
call.summary.usage[key].input_tokens ||
0;
totalTokens +=
call.summary.usage[key].completion_tokens ||
call.summary.usage[key].output_tokens ||
0;
}
}

const [entityName, projectName] = call?.project_id?.split('/') || [];
const callId = call?.id || '';
const latency = call?.summary?.weave?.latency_ms;
const {peekingRouter} = useWeaveflowRouteContext();
const history = useHistory();

@@ -43,21 +31,35 @@ export const PlaygroundCallStats = ({call}: {call: TraceCallSchema}) => {
false
);

const latency = call?.summary?.weave?.latency_ms ?? 0;
const usageData = call?.summary?.usage;
const costData = call?.summary?.weave?.costs;

return (
<Tailwind>
<div className="flex w-full flex-wrap items-center justify-center gap-8 py-8 text-sm text-moon-500">
<span>Latency: {latency}ms</span>
<span></span>
<div className="flex w-full items-center justify-center gap-8 py-8">
<TraceCostStats
usageData={usageData}
costData={costData}
latency_ms={latency}
costLoading={false}
/>
{(call.output as any)?.choices?.[0]?.finish_reason && (
<>
<span>
Finish reason: {(call.output as any).choices[0].finish_reason}
</span>
<span></span>
</>
<Tooltip
content="Finish reason"
trigger={
// Placing in span so tooltip shows up
<span>
<Pill
icon="checkmark-circle"
label={(call.output as any).choices[0].finish_reason}
color="moon"
className="-ml-[8px] bg-transparent text-moon-500 dark:bg-transparent dark:text-moon-500"
/>
</span>
}
/>
)}
<span>{totalTokens} tokens</span>
<span></span>
{callLink && (
<Button
size="small"
@@ -60,12 +60,28 @@ export const PlaygroundPageInner = (props: PlaygroundPageProps) => {
callId: props.callId,
}
: null;
}, [props.entity, props.project, props.callId])
}, [props.entity, props.project, props.callId]),
{
includeCosts: true,
}
);

const {result: calls} = useCalls(props.entity, props.project, {
callIds: playgroundStates.map(state => state.traceCall.id || ''),
});
const {result: calls} = useCalls(
props.entity,
props.project,
{
callIds: playgroundStates.map(state => state.traceCall.id || ''),
},
undefined,
undefined,
undefined,
undefined,
undefined,
undefined,
{
includeCosts: true,
}
);

useEffect(() => {
if (!call.loading && call.result) {
6 changes: 3 additions & 3 deletions weave/trace_server/clickhouse_trace_server_batched.py
@@ -1447,8 +1447,8 @@ def completions_create(
if not secret_name:
raise InvalidRequest(f"No secret name found for model {model_name}")
api_key = secret_fetcher.fetch(secret_name).get("secrets", {}).get(secret_name)
isBedrock = model_info.get("litellm_provider") == "bedrock"
if not api_key and not isBedrock:
provider = model_info.get("litellm_provider")
if not api_key and provider != "bedrock" and provider != "bedrock_converse":
raise MissingLLMApiKeyError(
f"No API key {secret_name} found for model {model_name}",
api_key_name=secret_name,
@@ -1458,7 +1458,7 @@
res = lite_llm_completion(
api_key,
req.inputs,
isBedrock,
provider,
)
end_time = datetime.datetime.now()

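Taken together, the new check means an API key is required unless the provider is one of the two Bedrock variants, which authenticate with AWS credentials instead; the provider string (rather than the old isBedrock flag) is then forwarded to lite_llm_completion. A condensed sketch of that rule, with illustrative names (BEDROCK_PROVIDERS and the generic exception are not in the source):

from typing import Optional

BEDROCK_PROVIDERS = {"bedrock", "bedrock_converse"}  # illustrative constant, not in the source

def require_api_key(api_key: Optional[str], provider: Optional[str]) -> None:
    # Bedrock-backed models authenticate with AWS credentials rather than an
    # API key, so only non-Bedrock providers fail hard when the key is missing.
    if not api_key and provider not in BEDROCK_PROVIDERS:
        raise RuntimeError("missing LLM API key")  # stands in for MissingLLMApiKeyError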
13 changes: 11 additions & 2 deletions weave/trace_server/llm_completion.py
@@ -7,17 +7,26 @@
)
from weave.trace_server.secret_fetcher_context import _secret_fetcher_context

NOVA_MODELS = ("nova-pro-v1", "nova-lite-v1", "nova-micro-v1")


def lite_llm_completion(
api_key: str,
inputs: tsi.CompletionsCreateRequestInputs,
isBedrock: bool,
provider: Optional[str] = None,
) -> tsi.CompletionsCreateRes:
aws_access_key_id, aws_secret_access_key, aws_region_name = None, None, None
if isBedrock:
if provider == "bedrock" or provider == "bedrock_converse":
aws_access_key_id, aws_secret_access_key, aws_region_name = (
get_bedrock_credentials(inputs.model)
)
# Nova models need the region in the model name
if any(x in inputs.model for x in NOVA_MODELS) and aws_region_name:
aws_inference_region = aws_region_name.split("-")[0]
inputs.model = "bedrock/" + aws_inference_region + "." + inputs.model
# XAI models don't support response_format
elif provider == "xai":
inputs.response_format = None

import litellm

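The Nova branch rewrites the model id so LiteLLM can route it through a region-scoped Bedrock inference profile: the AWS region is truncated to its first segment and prepended, e.g. us-east-1 yields a "bedrock/us." prefix. The xai branch simply drops response_format, which those models do not support. A small worked example of the Nova rewrite (a standalone reimplementation for illustration; the model id is just an example):

from typing import Optional

NOVA_MODELS = ("nova-pro-v1", "nova-lite-v1", "nova-micro-v1")

def prefix_nova_model(model: str, aws_region_name: Optional[str]) -> str:
    # "us-east-1" -> "us", producing "bedrock/us.<model>"
    if any(x in model for x in NOVA_MODELS) and aws_region_name:
        region_prefix = aws_region_name.split("-")[0]
        return "bedrock/" + region_prefix + "." + model
    return model

print(prefix_nova_model("amazon.nova-lite-v1:0", "us-east-1"))
# -> bedrock/us.amazon.nova-lite-v1:0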
2 changes: 1 addition & 1 deletion weave/trace_server/model_providers/model_providers.json

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions weave/trace_server/model_providers/model_providers.py
@@ -16,6 +16,8 @@
"fireworks": "FIREWORKS_API_KEY",
"groq": "GEMMA_API_KEY",
"bedrock": "BEDROCK_API_KEY",
"bedrock_converse": "BEDROCK_API_KEY",
"xai": "XAI_API_KEY",
}
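
These two entries plug into the provider-to-secret lookup used by completions_create: a model's litellm_provider resolves to an environment-variable name, and that name is what the secret fetcher retrieves. A rough sketch of the resolution step (the dictionary and function names here are illustrative, not the source's):

from typing import Optional

PROVIDER_TO_API_KEY_NAME = {  # illustrative subset of the map above
    "bedrock": "BEDROCK_API_KEY",
    "bedrock_converse": "BEDROCK_API_KEY",
    "xai": "XAI_API_KEY",
}

def resolve_secret_name(model_info: dict) -> Optional[str]:
    # The provider recorded for the model decides which secret to fetch.
    return PROVIDER_TO_API_KEY_NAME.get(model_info.get("litellm_provider"))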


