From b8d74171529f83b5d371795ebc3cc8a7e2f27515 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Tue, 8 Oct 2024 11:26:59 -0700 Subject: [PATCH] chore: add column selection in eval summary --- .../pages/CompareEvaluationsPage/ecpState.ts | 3 + .../SummaryPlotsSection.tsx | 219 +++++++++++++++++- 2 files changed, 216 insertions(+), 6 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts index ea28cc417e8..dfb0da6560b 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts @@ -24,6 +24,8 @@ export type EvaluationComparisonState = { comparisonDimensions?: ComparisonDimensionsType; // The current digest which is in view selectedInputDigest?: string; + // Selected metrics to compare + selectedMetrics?: string[]; }; export type ComparisonDimensionsType = Array<{ @@ -93,6 +95,7 @@ export const useEvaluationComparisonState = ( baselineEvaluationCallId ?? evaluationCallIds[0], comparisonDimensions: newComparisonDimensions, selectedInputDigest, + selectedMetrics: newComparisonDimensions.map(dim => dim.metricId), }, }; }, [ diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx index 4e685d14f93..4f6c8595267 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx @@ -1,5 +1,16 @@ import {Box} from '@material-ui/core'; -import React, {useMemo} from 'react'; +import {Popover} from '@mui/material'; +import {Switch} from '@wandb/weave/components'; +import {Button} from '@wandb/weave/components/Button'; +import { + DraggableGrow, + DraggableHandle, +} from '@wandb/weave/components/DraggablePopups'; +import {TextField} from '@wandb/weave/components/Form/TextField'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import {maybePluralize} from '@wandb/weave/core/util/string'; +import classNames from 'classnames'; +import React, {useMemo, useRef, useState} from 'react'; import {buildCompositeMetricsMap} from '../../compositeMetricsUtil'; import { @@ -25,7 +36,33 @@ import {PlotlyRadarPlot, RadarPlotData} from './PlotlyRadarPlot'; export const SummaryPlots: React.FC<{ state: EvaluationComparisonState; }> = props => { - const plotlyRadarData = useNormalizedPlotDataFromMetrics(props.state); + const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics( + props.state + ); + const [selectedMetrics, setSelectedMetrics] = useState( + Array.from(allMetricNames) + ); + + // filter down the plotlyRadarData to only include the selected metrics + const filteredPlotlyRadarData = useMemo(() => { + const filteredData: RadarPlotData = {}; + for (const [callId, metricBin] of Object.entries(radarData)) { + const metrics: {[metric: string]: number} = {}; + for (const [metric, value] of Object.entries(metricBin.metrics)) { + if (selectedMetrics.includes(metric)) { + metrics[metric] = value; + } + } + if (Object.keys(metrics).length > 0) { + filteredData[callId] = { + metrics: metrics, + name: metricBin.name, + color: metricBin.color, + }; + } + } + return filteredData; + }, [radarData, selectedMetrics]); return ( Summary Metrics + +
+
Configure displayed metrics
+ +
+
- + - +
); }; +const MetricsSelector: React.FC<{ + setSelectedMetrics: (metrics: string[]) => void; + selectedMetrics: string[]; + allMetrics: string[]; +}> = ({setSelectedMetrics, selectedMetrics, allMetrics}) => { + const [search, setSearch] = useState(''); + + const ref = useRef(null); + const [anchorEl, setAnchorEl] = useState(null); + const onClick = (event: React.MouseEvent) => { + setAnchorEl(anchorEl ? null : ref.current); + setSearch(''); + }; + const open = Boolean(anchorEl); + const id = open ? 'simple-popper' : undefined; + + const filteredCols = search + ? allMetrics.filter(col => col.toLowerCase().includes(search.toLowerCase())) + : allMetrics; + + const numHidden = allMetrics.length - selectedMetrics.length; + const buttonSuffix = search ? `(${filteredCols.length})` : 'all'; + + return ( + <> + + +
+ +
+ + + + + ); +}; + const normalizeValues = (values: Array): number[] => { // find the max value // find the power of 2 that is greater than the max value @@ -93,7 +298,7 @@ const normalizeValues = (values: Array): number[] => { const useNormalizedPlotDataFromMetrics = ( state: EvaluationComparisonState -): RadarPlotData => { +): {radarData: RadarPlotData; allMetricNames: Set} => { const compositeMetrics = useMemo(() => { return buildCompositeMetricsMap(state.data, 'summary'); }, [state]); @@ -137,7 +342,7 @@ const useNormalizedPlotDataFromMetrics = ( evalScores, }; }); - return Object.fromEntries( + const radarData = Object.fromEntries( callIds.map(callId => { const evalCall = state.data.evaluationCalls[callId]; return [ @@ -157,5 +362,7 @@ const useNormalizedPlotDataFromMetrics = ( ]; }) ); + const allMetricNames = new Set(normalizedMetrics.map(m => m.metricLabel)); + return {radarData, allMetricNames}; }, [callIds, compositeMetrics, state.data.evaluationCalls]); };