diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index 0635fbfbb50..1fb9ebb2087 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -928,6 +928,7 @@ const CompareEvaluationsBinding = () => { const evaluationCallIds = useMemo(() => { return JSON.parse(query.evaluationCallIds); }, [query.evaluationCallIds]); + const onEvaluationCallIdsUpdate = useCallback( (newEvaluationCallIds: string[]) => { const newQuery = new URLSearchParams(location.search); @@ -936,12 +937,27 @@ const CompareEvaluationsBinding = () => { }, [history, location.search] ); + + const selectedMetrics: Record | null = useMemo(() => { + try { + return JSON.parse(query.metrics); + } catch (e) { + return null; + } + }, [query.metrics]); + const setSelectedMetrics = (newModel: Record) => { + const newQuery = new URLSearchParams(location.search); + newQuery.set('metrics', JSON.stringify(newModel)); + history.push({search: newQuery.toString()}); + }; return ( ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index 8e179da6742..037c0f9c8e9 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -64,6 +64,9 @@ const useCallTabs = (call: CallSchema) => { entity={call.entity} project={call.project} evaluationCallIds={[call.callId]} + // Dont persist metric selection in the URL + selectedMetrics={{}} + setSelectedMetrics={() => {}} // Dont persist changes to evaluationCallIds in the URL onEvaluationCallIdsUpdate={() => {}} /> diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx index 7cd3385f8ae..58148278cfb 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx @@ -39,6 +39,8 @@ type CompareEvaluationsPageProps = { project: string; evaluationCallIds: string[]; onEvaluationCallIdsUpdate: (newEvaluationCallIds: string[]) => void; + selectedMetrics: Record | null; + setSelectedMetrics: (newModel: Record) => void; }; export const CompareEvaluationsPage: React.FC< @@ -57,6 +59,8 @@ export const CompareEvaluationsPage: React.FC< project={props.project} evaluationCallIds={props.evaluationCallIds} onEvaluationCallIdsUpdate={props.onEvaluationCallIdsUpdate} + selectedMetrics={props.selectedMetrics} + setSelectedMetrics={props.setSelectedMetrics} /> ), }, @@ -112,6 +116,8 @@ export const CompareEvaluationsPageContent: React.FC< = ({ const CompareEvaluationsPageInner: React.FC<{ height: number; }> = props => { - const {state} = useCompareEvaluationsState(); + const {state, setSelectedMetrics} = useCompareEvaluationsState(); const showExampleFilter = Object.keys(state.data.evaluationCalls).length === 2; const showExamples = Object.keys(state.data.resultRows).length > 0; @@ -200,7 +206,7 @@ const CompareEvaluationsPageInner: React.FC<{ evaluationCalls={Object.values(state.data.evaluationCalls)} /> - + {showExamples ? ( <> diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx index 768ab2555a2..c72e07bc1fc 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx @@ -16,6 +16,7 @@ const CompareEvaluationsContext = React.createContext<{ React.SetStateAction >; setSelectedInputDigest: React.Dispatch>; + setSelectedMetrics: (newModel: Record) => void; addEvaluationCall: (newCallId: string) => void; removeEvaluationCall: (callId: string) => void; } | null>(null); @@ -31,6 +32,9 @@ export const useCompareEvaluationsState = () => { export const CompareEvaluationsProvider: React.FC<{ entity: string; project: string; + selectedMetrics: Record | null; + setSelectedMetrics: (newModel: Record) => void; + initialEvaluationCallIds: string[]; onEvaluationCallIdsUpdate: (newEvaluationCallIds: string[]) => void; setBaselineEvaluationCallId: React.Dispatch< @@ -46,13 +50,15 @@ export const CompareEvaluationsProvider: React.FC<{ }> = ({ entity, project, + selectedMetrics, + setSelectedMetrics, + initialEvaluationCallIds, onEvaluationCallIdsUpdate, setBaselineEvaluationCallId, setComparisonDimensions, setSelectedInputDigest, - baselineEvaluationCallId, comparisonDimensions, selectedInputDigest, @@ -67,7 +73,8 @@ export const CompareEvaluationsProvider: React.FC<{ evaluationCallIds, baselineEvaluationCallId, comparisonDimensions, - selectedInputDigest + selectedInputDigest, + selectedMetrics ?? undefined ); const value = useMemo(() => { @@ -79,6 +86,7 @@ export const CompareEvaluationsProvider: React.FC<{ setBaselineEvaluationCallId, setComparisonDimensions, setSelectedInputDigest, + setSelectedMetrics, addEvaluationCall: (newCallId: string) => { const newEvaluationCallIds = [...evaluationCallIds, newCallId]; setEvaluationCallIds(newEvaluationCallIds); @@ -101,6 +109,7 @@ export const CompareEvaluationsProvider: React.FC<{ setBaselineEvaluationCallId, setComparisonDimensions, setSelectedInputDigest, + setSelectedMetrics, ]); if (!value) { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compositeMetricsUtil.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compositeMetricsUtil.ts index c286541d234..a29fb74ffe9 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compositeMetricsUtil.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compositeMetricsUtil.ts @@ -75,7 +75,8 @@ export type CompositeSummaryMetricGroupForKeyPath = { */ export const buildCompositeMetricsMap = ( data: EvaluationComparisonData, - mType: MetricType + mType: MetricType, + selectedMetrics: Record | undefined = undefined ): CompositeScoreMetrics => { const composite: CompositeScoreMetrics = {}; @@ -93,6 +94,12 @@ export const buildCompositeMetricsMap = ( Object.entries(metricDefinitionMap).forEach(([metricId, metric]) => { const groupName = groupNameForMetric(metric); const ref = refForMetric(metric); + const keyPath = flattenedDimensionPath(metric); + + if (selectedMetrics && !selectedMetrics[keyPath]) { + // Skip metrics that are not in the selectedMetrics map + return; + } if (!composite[groupName]) { composite[groupName] = { @@ -105,8 +112,6 @@ export const buildCompositeMetricsMap = ( metricGroup.scorerRefs.push(ref); } - const keyPath = flattenedDimensionPath(metric); - if (!metricGroup.metrics[keyPath]) { metricGroup.metrics[keyPath] = { scorerAgnosticMetricDef: _.omit(metric, 'scorerOpOrObjRef'), diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts index ea28cc417e8..35f95dbf14f 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts @@ -24,6 +24,8 @@ export type EvaluationComparisonState = { comparisonDimensions?: ComparisonDimensionsType; // The current digest which is in view selectedInputDigest?: string; + // The selected metrics to display + selectedMetrics?: Record; }; export type ComparisonDimensionsType = Array<{ @@ -43,7 +45,8 @@ export const useEvaluationComparisonState = ( evaluationCallIds: string[], baselineEvaluationCallId?: string, comparisonDimensions?: ComparisonDimensionsType, - selectedInputDigest?: string + selectedInputDigest?: string, + selectedMetrics?: Record ): Loadable => { const data = useEvaluationComparisonData(entity, project, evaluationCallIds); @@ -93,6 +96,7 @@ export const useEvaluationComparisonState = ( baselineEvaluationCallId ?? evaluationCallIds[0], comparisonDimensions: newComparisonDimensions, selectedInputDigest, + selectedMetrics, }, }; }, [ @@ -102,6 +106,7 @@ export const useEvaluationComparisonState = ( evaluationCallIds, comparisonDimensions, selectedInputDigest, + selectedMetrics, ]); return value; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx index 2573c7f4477..e0c3938f16e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx @@ -229,8 +229,13 @@ export const ExampleCompareSection: React.FC<{ const {ref1, ref2} = useLinkHorizontalScroll(); const compositeScoreMetrics = useMemo( - () => buildCompositeMetricsMap(props.state.data, 'score'), - [props.state.data] + () => + buildCompositeMetricsMap( + props.state.data, + 'score', + props.state.selectedMetrics + ), + [props.state.data, props.state.selectedMetrics] ); if (target == null) { @@ -252,9 +257,13 @@ export const ExampleCompareSection: React.FC<{ .length; }); const numEvals = numTrials.length; + // Get derived scores, then filter out any not in the selected metrics const derivedScores = Object.values( getMetricIds(props.state.data, 'score', 'derived') + ).filter( + score => props.state.selectedMetrics?.[flattenedDimensionPath(score)] ); + const numMetricScorers = metricGroupNames.length; const numDerivedScores = derivedScores.length; const numMetricsPerScorer = [ diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts index 59570f4c14e..bf20f6bffa4 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts @@ -118,8 +118,8 @@ const rowIsSelected = ( export const useFilteredAggregateRows = (state: EvaluationComparisonState) => { const leafDims = useMemo(() => getOrderedCallIds(state), [state]); const compositeMetricsMap = useMemo( - () => buildCompositeMetricsMap(state.data, 'score'), - [state.data] + () => buildCompositeMetricsMap(state.data, 'score', state.selectedMetrics), + [state.data, state.selectedMetrics] ); const flattenedRows = useMemo(() => { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx index 2422104f964..2b39cc1578d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx @@ -114,7 +114,11 @@ export const ScorecardSection: React.FC<{ const [diffOnly, setDiffOnly] = React.useState(true); const compositeSummaryMetrics = useMemo(() => { - return buildCompositeMetricsMap(props.state.data, 'summary'); + return buildCompositeMetricsMap( + props.state.data, + 'summary', + props.state.selectedMetrics + ); }, [props.state]); const onCallClick = usePeekCall( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx index 4e685d14f93..5bfaa8fcb04 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx @@ -1,5 +1,16 @@ import {Box} from '@material-ui/core'; -import React, {useMemo} from 'react'; +import {Popover} from '@mui/material'; +import {Switch} from '@wandb/weave/components'; +import {Button} from '@wandb/weave/components/Button'; +import { + DraggableGrow, + DraggableHandle, +} from '@wandb/weave/components/DraggablePopups'; +import {TextField} from '@wandb/weave/components/Form/TextField'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import {maybePluralize} from '@wandb/weave/core/util/string'; +import classNames from 'classnames'; +import React, {useEffect, useMemo, useRef, useState} from 'react'; import {buildCompositeMetricsMap} from '../../compositeMetricsUtil'; import { @@ -24,8 +35,44 @@ import {PlotlyRadarPlot, RadarPlotData} from './PlotlyRadarPlot'; */ export const SummaryPlots: React.FC<{ state: EvaluationComparisonState; + setSelectedMetrics: (newModel: Record) => void; }> = props => { - const plotlyRadarData = useNormalizedPlotDataFromMetrics(props.state); + const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics( + props.state + ); + const {selectedMetrics} = props.state; + const setSelectedMetrics = props.setSelectedMetrics; + + useEffect(() => { + // If selectedMetrics is null, we should show all metrics + if (selectedMetrics == null) { + setSelectedMetrics( + Object.fromEntries(Array.from(allMetricNames).map(m => [m, true])) + ); + } + }, [selectedMetrics, setSelectedMetrics, allMetricNames]); + + // filter down the plotlyRadarData to only include the selected metrics, after + // computation, to allow quick addition/removal of metrics + const filteredPlotlyRadarData = useMemo(() => { + const filteredData: RadarPlotData = {}; + for (const [callId, metricBin] of Object.entries(radarData)) { + const metrics: {[metric: string]: number} = {}; + for (const [metric, value] of Object.entries(metricBin.metrics)) { + if (selectedMetrics?.[metric]) { + metrics[metric] = value; + } + } + if (Object.keys(metrics).length > 0) { + filteredData[callId] = { + metrics, + name: metricBin.name, + color: metricBin.color, + }; + } + } + return filteredData; + }, [radarData, selectedMetrics]); return ( Summary Metrics + +
+
Configure displayed metrics
+ +
+
- + - +
); }; +const MetricsSelector: React.FC<{ + setSelectedMetrics: (newModel: Record) => void; + selectedMetrics: Record | undefined; + allMetrics: string[]; +}> = ({setSelectedMetrics, selectedMetrics, allMetrics}) => { + const [search, setSearch] = useState(''); + + const ref = useRef(null); + const [anchorEl, setAnchorEl] = useState(null); + const onClick = (event: React.MouseEvent) => { + setAnchorEl(anchorEl ? null : ref.current); + setSearch(''); + }; + const open = Boolean(anchorEl); + const id = open ? 'simple-popper' : undefined; + + const filteredCols = search + ? allMetrics.filter(col => col.toLowerCase().includes(search.toLowerCase())) + : allMetrics; + + const shownMetrics = Object.values(selectedMetrics ?? {}).filter(Boolean); + + const numHidden = allMetrics.length - shownMetrics.length; + const buttonSuffix = search ? `(${filteredCols.length})` : 'all'; + + return ( + <> + + +
+ +
+ + + + + ); +}; + const normalizeValues = (values: Array): number[] => { // find the max value // find the power of 2 that is greater than the max value @@ -93,7 +314,7 @@ const normalizeValues = (values: Array): number[] => { const useNormalizedPlotDataFromMetrics = ( state: EvaluationComparisonState -): RadarPlotData => { +): {radarData: RadarPlotData; allMetricNames: Set} => { const compositeMetrics = useMemo(() => { return buildCompositeMetricsMap(state.data, 'summary'); }, [state]); @@ -137,7 +358,7 @@ const useNormalizedPlotDataFromMetrics = ( evalScores, }; }); - return Object.fromEntries( + const radarData = Object.fromEntries( callIds.map(callId => { const evalCall = state.data.evaluationCalls[callId]; return [ @@ -157,5 +378,7 @@ const useNormalizedPlotDataFromMetrics = ( ]; }) ); + const allMetricNames = new Set(normalizedMetrics.map(m => m.metricLabel)); + return {radarData, allMetricNames}; }, [callIds, compositeMetrics, state.data.evaluationCalls]); };