diff --git a/tests/trace/test_table_query.py b/tests/trace/test_table_query.py index 1a45b3963e3..fe9e3e2320a 100644 --- a/tests/trace/test_table_query.py +++ b/tests/trace/test_table_query.py @@ -1,4 +1,5 @@ import random +from typing import Iterator from weave.trace.weave_client import WeaveClient from weave.trace_server import trace_server_interface as tsi @@ -55,6 +56,28 @@ def test_table_query(client: WeaveClient): assert result_digests == row_digests +def test_table_query_stream(client: WeaveClient): + digest, row_digests, data = generate_table_data(client, 10, 10) + + res = client.server.table_query_stream( + tsi.TableQueryReq( + project_id=client._project_id(), + digest=digest, + ) + ) + + assert isinstance(res, Iterator) + rows = [] + for r in res: + rows.append(r) + + result_vals = [r.val for r in rows] + result_digests = [r.digest for r in rows] + + assert result_vals == data + assert result_digests == row_digests + + def test_table_query_invalid_digest(client: WeaveClient): res = client.server.table_query( tsi.TableQueryReq( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableQuery.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableQuery.ts index 87a5ec9423a..b80a9905f9d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableQuery.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableQuery.ts @@ -33,7 +33,8 @@ export const useCallsForQuery = ( gridFilter: GridFilterModel, gridSort: GridSortModel, gridPage: GridPaginationModel, - expandedColumns: Set + expandedColumns: Set, + columns?: string[] ): { result: CallSchema[]; loading: boolean; @@ -57,7 +58,7 @@ export const useCallsForQuery = ( offset, sortBy, filterBy, - undefined, + columns, expandedColumns, { refetchOnDelete: true, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx index a269463ff4e..05da3ba49ed 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx @@ -95,7 +95,9 @@ export const CompareEvaluationsPageContent: React.FC< ); React.useEffect(() => { - if (props.evaluationCallIds.length > 0) { + // Only update the baseline if we are switching evaluations, if there + // is more than 1, we are in the compare view and baseline is auto set + if (props.evaluationCallIds.length === 1) { setBaselineEvaluationCallId(props.evaluationCallIds[0]); } }, [props.evaluationCallIds]); @@ -108,7 +110,7 @@ export const CompareEvaluationsPageContent: React.FC< >; setSelectedInputDigest: React.Dispatch>; + addEvaluationCall: (newCallId: string) => void; + removeEvaluationCall: (callId: string) => void; } | null>(null); export const useCompareEvaluationsState = () => { @@ -29,7 +31,7 @@ export const useCompareEvaluationsState = () => { export const CompareEvaluationsProvider: React.FC<{ entity: string; project: string; - evaluationCallIds: string[]; + initialEvaluationCallIds: string[]; setBaselineEvaluationCallId: React.Dispatch< React.SetStateAction >; @@ -43,7 +45,7 @@ export const CompareEvaluationsProvider: React.FC<{ }> = ({ entity, project, - evaluationCallIds, + initialEvaluationCallIds, setBaselineEvaluationCallId, setComparisonDimensions, @@ -54,6 +56,9 @@ export const CompareEvaluationsProvider: React.FC<{ selectedInputDigest, children, }) => { + const [evaluationCallIds, setEvaluationCallIds] = useState( + initialEvaluationCallIds + ); const initialState = useEvaluationComparisonState( entity, project, @@ -72,10 +77,17 @@ export const CompareEvaluationsProvider: React.FC<{ setBaselineEvaluationCallId, setComparisonDimensions, setSelectedInputDigest, + addEvaluationCall: (newCallId: string) => { + setEvaluationCallIds(prev => [...prev, newCallId]); + }, + removeEvaluationCall: (callId: string) => { + setEvaluationCallIds(prev => prev.filter(id => id !== callId)); + }, }; }, [ initialState.loading, initialState.result, + setEvaluationCallIds, setBaselineEvaluationCallId, setComparisonDimensions, setSelectedInputDigest, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx index a7fb35f1746..3d461681a3c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx @@ -1,12 +1,28 @@ -import React, {useMemo} from 'react'; +import {Popover} from '@mui/material'; +import Input from '@wandb/weave/common/components/Input'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import {parseRef, WeaveObjectRef} from '@wandb/weave/react'; +import React, {useEffect, useMemo, useRef, useState} from 'react'; import {Button} from '../../../../../../../Button'; +import { + DEFAULT_FILTER_CALLS, + DEFAULT_SORT_CALLS, +} from '../../../CallsPage/CallsTable'; +import {useCallsForQuery} from '../../../CallsPage/callsTableQuery'; +import {useEvaluationsFilter} from '../../../CallsPage/evaluationsFilter'; +import {Id} from '../../../common/Id'; +import {useWFHooks} from '../../../wfReactInterface/context'; +import { + CallSchema, + ObjectVersionKey, +} from '../../../wfReactInterface/wfDataModelHooksInterface'; import {useCompareEvaluationsState} from '../../compareEvaluationsContext'; import {STANDARD_PADDING} from '../../ecpConstants'; import {getOrderedCallIds} from '../../ecpState'; import {EvaluationComparisonState} from '../../ecpState'; import {HorizontalBox} from '../../Layout'; -import {EvaluationDefinition} from './EvaluationDefinition'; +import {EvaluationDefinition, VerticalBar} from './EvaluationDefinition'; export const ComparisonDefinitionSection: React.FC<{ state: EvaluationComparisonState; @@ -28,25 +44,200 @@ export const ComparisonDefinitionSection: React.FC<{ {evalCallIds.map((key, ndx) => { return ( - {ndx !== 0 && } ); })} + ); }; -const SwapPositionsButton: React.FC<{callId: string}> = props => { - const {setBaselineEvaluationCallId} = useCompareEvaluationsState(); +const ModelRefLabel: React.FC<{modelRef: string}> = props => { + const {useObjectVersion} = useWFHooks(); + const objRef = useMemo( + () => parseRef(props.modelRef) as WeaveObjectRef, + [props.modelRef] + ); + const objVersionKey = useMemo(() => { + return { + scheme: 'weave', + entity: objRef.entityName, + project: objRef.projectName, + weaveKind: objRef.weaveKind, + objectId: objRef.artifactName, + versionHash: objRef.artifactVersion, + path: '', + refExtra: objRef.artifactRefExtra, + } as ObjectVersionKey; + }, [ + objRef.artifactName, + objRef.artifactRefExtra, + objRef.artifactVersion, + objRef.entityName, + objRef.projectName, + objRef.weaveKind, + ]); + const objectVersion = useObjectVersion(objVersionKey); + return ( + + {objectVersion.result?.objectId}:{objectVersion.result?.versionIndex} + + ); +}; + +const AddEvaluationButton: React.FC<{ + state: EvaluationComparisonState; +}> = props => { + const {addEvaluationCall} = useCompareEvaluationsState(); + + // Calls query for just evaluations + const evaluationsFilter = useEvaluationsFilter( + props.state.data.entity, + props.state.data.project + ); + const page = useMemo( + () => ({ + pageSize: 100, + page: 0, + }), + [] + ); + const expandedRefCols = useMemo(() => new Set(), []); + // Don't query for output here, re-queried in tsDataModelHooksEvaluationComparison.ts + const columns = useMemo(() => ['inputs'], []); + const calls = useCallsForQuery( + props.state.data.entity, + props.state.data.project, + evaluationsFilter, + DEFAULT_FILTER_CALLS, + DEFAULT_SORT_CALLS, + page, + expandedRefCols, + columns + ); + + const evalsNotComparing = useMemo(() => { + return calls.result.filter( + call => + !Object.keys(props.state.data.evaluationCalls).includes(call.callId) + ); + }, [calls.result, props.state.data.evaluationCalls]); + + const [menuOptions, setMenuOptions] = + useState(evalsNotComparing); + useEffect(() => { + setMenuOptions(evalsNotComparing); + }, [evalsNotComparing]); + + const onSearchChange = (e: React.ChangeEvent) => { + const search = e.target.value; + if (search === '') { + setMenuOptions(evalsNotComparing); + return; + } + + const filteredOptions = calls.result.filter(call => { + if ( + (call.displayName ?? call.spanName) + .toLowerCase() + .includes(search.toLowerCase()) + ) { + return true; + } + if (call.callId.slice(-4).includes(search)) { + return true; + } + const modelRef = parseRef(call.traceCall?.inputs.model) as WeaveObjectRef; + if (modelRef.artifactName.toLowerCase().includes(search.toLowerCase())) { + return true; + } + return false; + }); + + setMenuOptions(filteredOptions); + }; + + // Popover management + const refBar = useRef(null); + const refLabel = useRef(null); + const [anchorEl, setAnchorEl] = React.useState(null); + const onClick = (event: React.MouseEvent) => { + setAnchorEl(anchorEl ? null : refBar.current); + }; + const open = Boolean(anchorEl); + const id = open ? 'simple-popper' : undefined; + return ( - + + + setAnchorEl(null)}> + +
+ +
+ {menuOptions.length === 0 && ( +
No evaluations
+ )} + {menuOptions.map(call => ( +
+ +
+ ))} +
+
+
+
+ ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx index 576b17f5822..7fdb0c6c948 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx @@ -1,5 +1,8 @@ import {Box} from '@material-ui/core'; import {Circle} from '@mui/icons-material'; +import {PopupDropdown} from '@wandb/weave/common/components/PopupDropdown'; +import {Button} from '@wandb/weave/components/Button'; +import {Pill} from '@wandb/weave/components/Tag'; import React, {useMemo} from 'react'; import { @@ -14,6 +17,7 @@ import {SmallRef} from '../../../../../Browse2/SmallRef'; import {CallLink, ObjectVersionLink} from '../../../common/Links'; import {useWFHooks} from '../../../wfReactInterface/context'; import {ObjectVersionKey} from '../../../wfReactInterface/wfDataModelHooksInterface'; +import {useCompareEvaluationsState} from '../../compareEvaluationsContext'; import { BOX_RADIUS, CIRCLE_SIZE, @@ -27,6 +31,36 @@ export const EvaluationDefinition: React.FC<{ state: EvaluationComparisonState; callId: string; }> = props => { + const {removeEvaluationCall, setBaselineEvaluationCallId} = + useCompareEvaluationsState(); + + const menuOptions = useMemo(() => { + return [ + { + key: 'add-to-baseline', + content: 'Set as baseline', + onClick: () => { + setBaselineEvaluationCallId(props.callId); + }, + disabled: props.callId === props.state.baselineEvaluationCallId, + }, + { + key: 'remove', + content: 'Remove', + onClick: () => { + removeEvaluationCall(props.callId); + }, + disabled: Object.keys(props.state.data.evaluationCalls).length === 1, + }, + ]; + }, [ + props.callId, + props.state.baselineEvaluationCallId, + props.state.data.evaluationCalls, + removeEvaluationCall, + setBaselineEvaluationCallId, + ]); + return ( - - + {props.callId === props.state.baselineEvaluationCallId && ( + + )} + + } + /> ); }; + export const EvaluationCallLink: React.FC<{ callId: string; state: EvaluationComparisonState; @@ -147,11 +195,12 @@ const ModelIcon: React.FC = () => { ); }; -const VerticalBar: React.FC = () => { + +export const VerticalBar: React.FC = () => { return (
"ClickHouseTraceServer": - return cls( + # Explicitly calling `RemoteHTTPTraceServer` constructor here to ensure + # that type checking is applied to the constructor. + return ClickHouseTraceServer( host=wf_env.wf_clickhouse_host(), port=wf_env.wf_clickhouse_port(), user=wf_env.wf_clickhouse_user(), @@ -770,6 +772,12 @@ def add_new_row_needed_to_insert(row_data: Any) -> str: return tsi.TableUpdateRes(digest=digest, updated_row_digests=updated_digests) def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes: + rows = list(self.table_query_stream(req)) + return tsi.TableQueryRes(rows=rows) + + def table_query_stream( + self, req: tsi.TableQueryReq + ) -> Iterator[tsi.TableRowSchema]: conds = [] pb = ParamBuilder() if req.filter: @@ -790,7 +798,8 @@ def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes: direction="ASC" if sort.direction.lower() == "asc" else "DESC", ) sort_fields.append(field) - rows = self._table_query( + + rows = self._table_query_stream( req.project_id, req.digest, pb, @@ -799,9 +808,10 @@ def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes: limit=req.limit, offset=req.offset, ) - return tsi.TableQueryRes(rows=rows) + for row in rows: + yield row - def _table_query( + def _table_query_stream( self, project_id: str, digest: str, @@ -813,7 +823,7 @@ def _table_query( sort_fields: Optional[list[OrderField]] = None, limit: Optional[int] = None, offset: Optional[int] = None, - ) -> list[tsi.TableRowSchema]: + ) -> Iterator[tsi.TableRowSchema]: if not sort_fields: sort_fields = [ OrderField( @@ -850,12 +860,10 @@ def _table_query( offset=offset, ) - query_result = self.ch_client.query(query, parameters=pb.get_params()) + res = self._query_stream(query, parameters=pb.get_params()) - return [ - tsi.TableRowSchema(digest=r[0], val=json.loads(r[1])) - for r in query_result.result_rows - ] + for row in res: + yield tsi.TableRowSchema(digest=row[0], val=json.loads(row[1])) def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsRes: parameters: Dict[str, Any] = { @@ -1124,7 +1132,7 @@ def resolve_extra(extra: list[str], val: Any) -> PartialRefResult: raise ValueError("Will not resolve cross-project refs.") pb = ParamBuilder() row_digests_name = pb.add_param(row_digests) - rows = self._table_query( + rows_stream = self._table_query_stream( project_id=project_id_scope, digest=digest, pb=pb, @@ -1132,6 +1140,7 @@ def resolve_extra(extra: list[str], val: Any) -> PartialRefResult: f"digest IN {{{row_digests_name}: Array(String)}}" ], ) + rows = list(rows_stream) # Unpack the results into the target rows row_digest_vals = {r.digest: r.val for r in rows} for index, row_digest in index_digests: diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index 2f62d095464..588fd56dfa9 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -268,6 +268,14 @@ def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes: req.project_id = self._idc.ext_to_int_project_id(req.project_id) return self._ref_apply(self._internal_trace_server.table_query, req) + def table_query_stream( + self, req: tsi.TableQueryReq + ) -> Iterator[tsi.TableRowSchema]: + req.project_id = self._idc.ext_to_int_project_id(req.project_id) + return self._stream_ref_apply( + self._internal_trace_server.table_query_stream, req + ) + def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsRes: req.project_id = self._idc.ext_to_int_project_id(req.project_id) return self._ref_apply(self._internal_trace_server.table_query_stats, req) diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 202bfbe0c62..4df2fc19c56 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -1170,6 +1170,13 @@ def _select_objs_query( ) return result + def table_query_stream( + self, req: tsi.TableQueryReq + ) -> Iterator[tsi.TableRowSchema]: + results = self.table_query(req) + for row in results.rows: + yield row + def get_type(val: Any) -> str: if val == None: diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 8b83e40ff54..442ed223cac 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -829,6 +829,7 @@ def objs_query(self, req: ObjQueryReq) -> ObjQueryRes: ... def table_create(self, req: TableCreateReq) -> TableCreateRes: ... def table_update(self, req: TableUpdateReq) -> TableUpdateRes: ... def table_query(self, req: TableQueryReq) -> TableQueryRes: ... + def table_query_stream(self, req: TableQueryReq) -> Iterator[TableRowSchema]: ... def table_query_stats(self, req: TableQueryStatsReq) -> TableQueryStatsRes: ... def refs_read_batch(self, req: RefsReadBatchReq) -> RefsReadBatchRes: ... def file_create(self, req: FileCreateReq) -> FileCreateRes: ... diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index fd8fd5d3fb6..94228574d25 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -118,7 +118,9 @@ def ensure_project_exists( @classmethod def from_env(cls, should_batch: bool = False) -> "RemoteHTTPTraceServer": - return cls(weave_trace_server_url(), should_batch) + # Explicitly calling `RemoteHTTPTraceServer` constructor here to ensure + # that type checking is applied to the constructor. + return RemoteHTTPTraceServer(weave_trace_server_url(), should_batch) def set_auth(self, auth: Tuple[str, str]) -> None: self._auth = auth @@ -439,6 +441,14 @@ def table_query( "/table/query", req, tsi.TableQueryReq, tsi.TableQueryRes ) + def table_query_stream( + self, req: tsi.TableQueryReq + ) -> Iterator[tsi.TableRowSchema]: + # Need to manually iterate over this until the stram endpoint is built and shipped. + res = self.table_query(req) + for row in res.rows: + yield row + def table_query_stats( self, req: Union[tsi.TableQueryStatsReq, dict[str, Any]] ) -> tsi.TableQueryStatsRes: