Skip to content

Commit

Permalink
Merge branch 'master' into a9-nd-nb-output-clear
Browse files Browse the repository at this point in the history
  • Loading branch information
acompa authored Oct 10, 2024
2 parents 3a3638e + cd1ac69 commit a91bf8a
Show file tree
Hide file tree
Showing 11 changed files with 349 additions and 36 deletions.
23 changes: 23 additions & 0 deletions tests/trace/test_table_query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
from typing import Iterator

from weave.trace.weave_client import WeaveClient
from weave.trace_server import trace_server_interface as tsi
Expand Down Expand Up @@ -55,6 +56,28 @@ def test_table_query(client: WeaveClient):
assert result_digests == row_digests


def test_table_query_stream(client: WeaveClient):
digest, row_digests, data = generate_table_data(client, 10, 10)

res = client.server.table_query_stream(
tsi.TableQueryReq(
project_id=client._project_id(),
digest=digest,
)
)

assert isinstance(res, Iterator)
rows = []
for r in res:
rows.append(r)

result_vals = [r.val for r in rows]
result_digests = [r.digest for r in rows]

assert result_vals == data
assert result_digests == row_digests


def test_table_query_invalid_digest(client: WeaveClient):
res = client.server.table_query(
tsi.TableQueryReq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ export const useCallsForQuery = (
gridFilter: GridFilterModel,
gridSort: GridSortModel,
gridPage: GridPaginationModel,
expandedColumns: Set<string>
expandedColumns: Set<string>,
columns?: string[]
): {
result: CallSchema[];
loading: boolean;
Expand All @@ -57,7 +58,7 @@ export const useCallsForQuery = (
offset,
sortBy,
filterBy,
undefined,
columns,
expandedColumns,
{
refetchOnDelete: true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ export const CompareEvaluationsPageContent: React.FC<
);

React.useEffect(() => {
if (props.evaluationCallIds.length > 0) {
// Only update the baseline if we are switching evaluations, if there
// is more than 1, we are in the compare view and baseline is auto set
if (props.evaluationCallIds.length === 1) {
setBaselineEvaluationCallId(props.evaluationCallIds[0]);
}
}, [props.evaluationCallIds]);
Expand All @@ -108,7 +110,7 @@ export const CompareEvaluationsPageContent: React.FC<
<CompareEvaluationsProvider
entity={props.entity}
project={props.project}
evaluationCallIds={props.evaluationCallIds}
initialEvaluationCallIds={props.evaluationCallIds}
baselineEvaluationCallId={baselineEvaluationCallId ?? undefined}
comparisonDimensions={comparisonDimensions ?? undefined}
setBaselineEvaluationCallId={setBaselineEvaluationCallId}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import {Box} from '@material-ui/core';
import React, {useMemo} from 'react';
import React, {useMemo, useState} from 'react';

import {WeaveLoader} from '../../../../../../common/components/WeaveLoader';
import {LinearProgress} from '../../../../../LinearProgress';
Expand All @@ -16,6 +16,8 @@ const CompareEvaluationsContext = React.createContext<{
React.SetStateAction<ComparisonDimensionsType | null>
>;
setSelectedInputDigest: React.Dispatch<React.SetStateAction<string | null>>;
addEvaluationCall: (newCallId: string) => void;
removeEvaluationCall: (callId: string) => void;
} | null>(null);

export const useCompareEvaluationsState = () => {
Expand All @@ -29,7 +31,7 @@ export const useCompareEvaluationsState = () => {
export const CompareEvaluationsProvider: React.FC<{
entity: string;
project: string;
evaluationCallIds: string[];
initialEvaluationCallIds: string[];
setBaselineEvaluationCallId: React.Dispatch<
React.SetStateAction<string | null>
>;
Expand All @@ -43,7 +45,7 @@ export const CompareEvaluationsProvider: React.FC<{
}> = ({
entity,
project,
evaluationCallIds,
initialEvaluationCallIds,
setBaselineEvaluationCallId,
setComparisonDimensions,

Expand All @@ -54,6 +56,9 @@ export const CompareEvaluationsProvider: React.FC<{
selectedInputDigest,
children,
}) => {
const [evaluationCallIds, setEvaluationCallIds] = useState(
initialEvaluationCallIds
);
const initialState = useEvaluationComparisonState(
entity,
project,
Expand All @@ -72,10 +77,17 @@ export const CompareEvaluationsProvider: React.FC<{
setBaselineEvaluationCallId,
setComparisonDimensions,
setSelectedInputDigest,
addEvaluationCall: (newCallId: string) => {
setEvaluationCallIds(prev => [...prev, newCallId]);
},
removeEvaluationCall: (callId: string) => {
setEvaluationCallIds(prev => prev.filter(id => id !== callId));
},
};
}, [
initialState.loading,
initialState.result,
setEvaluationCallIds,
setBaselineEvaluationCallId,
setComparisonDimensions,
setSelectedInputDigest,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
import React, {useMemo} from 'react';
import {Popover} from '@mui/material';
import Input from '@wandb/weave/common/components/Input';
import {Tailwind} from '@wandb/weave/components/Tailwind';
import {parseRef, WeaveObjectRef} from '@wandb/weave/react';
import React, {useEffect, useMemo, useRef, useState} from 'react';

import {Button} from '../../../../../../../Button';
import {
DEFAULT_FILTER_CALLS,
DEFAULT_SORT_CALLS,
} from '../../../CallsPage/CallsTable';
import {useCallsForQuery} from '../../../CallsPage/callsTableQuery';
import {useEvaluationsFilter} from '../../../CallsPage/evaluationsFilter';
import {Id} from '../../../common/Id';
import {useWFHooks} from '../../../wfReactInterface/context';
import {
CallSchema,
ObjectVersionKey,
} from '../../../wfReactInterface/wfDataModelHooksInterface';
import {useCompareEvaluationsState} from '../../compareEvaluationsContext';
import {STANDARD_PADDING} from '../../ecpConstants';
import {getOrderedCallIds} from '../../ecpState';
import {EvaluationComparisonState} from '../../ecpState';
import {HorizontalBox} from '../../Layout';
import {EvaluationDefinition} from './EvaluationDefinition';
import {EvaluationDefinition, VerticalBar} from './EvaluationDefinition';

export const ComparisonDefinitionSection: React.FC<{
state: EvaluationComparisonState;
Expand All @@ -28,25 +44,200 @@ export const ComparisonDefinitionSection: React.FC<{
{evalCallIds.map((key, ndx) => {
return (
<React.Fragment key={key}>
{ndx !== 0 && <SwapPositionsButton callId={key} />}
<EvaluationDefinition state={props.state} callId={key} />
</React.Fragment>
);
})}
<AddEvaluationButton state={props.state} />
</HorizontalBox>
);
};

const SwapPositionsButton: React.FC<{callId: string}> = props => {
const {setBaselineEvaluationCallId} = useCompareEvaluationsState();
const ModelRefLabel: React.FC<{modelRef: string}> = props => {
const {useObjectVersion} = useWFHooks();
const objRef = useMemo(
() => parseRef(props.modelRef) as WeaveObjectRef,
[props.modelRef]
);
const objVersionKey = useMemo(() => {
return {
scheme: 'weave',
entity: objRef.entityName,
project: objRef.projectName,
weaveKind: objRef.weaveKind,
objectId: objRef.artifactName,
versionHash: objRef.artifactVersion,
path: '',
refExtra: objRef.artifactRefExtra,
} as ObjectVersionKey;
}, [
objRef.artifactName,
objRef.artifactRefExtra,
objRef.artifactVersion,
objRef.entityName,
objRef.projectName,
objRef.weaveKind,
]);
const objectVersion = useObjectVersion(objVersionKey);
return (
<span className="ml-2">
{objectVersion.result?.objectId}:{objectVersion.result?.versionIndex}
</span>
);
};

const AddEvaluationButton: React.FC<{
state: EvaluationComparisonState;
}> = props => {
const {addEvaluationCall} = useCompareEvaluationsState();

// Calls query for just evaluations
const evaluationsFilter = useEvaluationsFilter(
props.state.data.entity,
props.state.data.project
);
const page = useMemo(
() => ({
pageSize: 100,
page: 0,
}),
[]
);
const expandedRefCols = useMemo(() => new Set<string>(), []);
// Don't query for output here, re-queried in tsDataModelHooksEvaluationComparison.ts
const columns = useMemo(() => ['inputs'], []);
const calls = useCallsForQuery(
props.state.data.entity,
props.state.data.project,
evaluationsFilter,
DEFAULT_FILTER_CALLS,
DEFAULT_SORT_CALLS,
page,
expandedRefCols,
columns
);

const evalsNotComparing = useMemo(() => {
return calls.result.filter(
call =>
!Object.keys(props.state.data.evaluationCalls).includes(call.callId)
);
}, [calls.result, props.state.data.evaluationCalls]);

const [menuOptions, setMenuOptions] =
useState<CallSchema[]>(evalsNotComparing);
useEffect(() => {
setMenuOptions(evalsNotComparing);
}, [evalsNotComparing]);

const onSearchChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const search = e.target.value;
if (search === '') {
setMenuOptions(evalsNotComparing);
return;
}

const filteredOptions = calls.result.filter(call => {
if (
(call.displayName ?? call.spanName)
.toLowerCase()
.includes(search.toLowerCase())
) {
return true;
}
if (call.callId.slice(-4).includes(search)) {
return true;
}
const modelRef = parseRef(call.traceCall?.inputs.model) as WeaveObjectRef;
if (modelRef.artifactName.toLowerCase().includes(search.toLowerCase())) {
return true;
}
return false;
});

setMenuOptions(filteredOptions);
};

// Popover management
const refBar = useRef<HTMLDivElement>(null);
const refLabel = useRef<HTMLDivElement>(null);
const [anchorEl, setAnchorEl] = React.useState<null | HTMLElement>(null);
const onClick = (event: React.MouseEvent<HTMLElement>) => {
setAnchorEl(anchorEl ? null : refBar.current);
};
const open = Boolean(anchorEl);
const id = open ? 'simple-popper' : undefined;

return (
<Button
size="medium"
variant="quiet"
onClick={() => {
setBaselineEvaluationCallId(props.callId);
}}
icon="retry"
/>
<>
<div
ref={refBar}
className="flex cursor-pointer items-center gap-4 rounded px-8 py-4 outline outline-moon-250 hover:outline-2 hover:outline-teal-500/40"
onClick={onClick}>
<div ref={refLabel}>
<Button variant="ghost" size="large" icon="add-new">
Add evaluation
</Button>
</div>
</div>
<Popover
id={id}
open={open}
anchorEl={anchorEl}
anchorOrigin={{
vertical: 'bottom',
horizontal: 'left',
}}
transformOrigin={{
vertical: 'top',
horizontal: 'left',
}}
slotProps={{
paper: {
sx: {
marginTop: '8px',
overflow: 'visible',
minWidth: '200px',
},
},
}}
onClose={() => setAnchorEl(null)}>
<Tailwind>
<div className="w-full p-12">
<Input
type="text"
placeholder="Search"
icon="search"
iconPosition="left"
onChange={onSearchChange}
className="w-full"
/>
<div className="mt-12 flex max-h-[400px] flex-col gap-2 overflow-y-auto">
{menuOptions.length === 0 && (
<div className="text-center text-moon-600">No evaluations</div>
)}
{menuOptions.map(call => (
<div key={call.callId} className="flex items-center gap-2">
<Button
variant="ghost"
size="small"
className="pb-8 pt-8 font-['Source_Sans_Pro'] text-base font-normal text-moon-800"
onClick={() => {
addEvaluationCall(call.callId);
}}>
<>
<span>{call.displayName ?? call.spanName}</span>
<Id id={call.callId} type="Call" className="ml-0 mr-4" />
<VerticalBar />
<ModelRefLabel modelRef={call.traceCall?.inputs.model} />
</>
</Button>
</div>
))}
</div>
</div>
</Tailwind>
</Popover>
</>
);
};
Loading

0 comments on commit a91bf8a

Please sign in to comment.