Skip to content

Commit

Permalink
chore: add column selection in eval summary
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning committed Oct 8, 2024
1 parent 5cf5f0b commit b8d7417
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ export type EvaluationComparisonState = {
comparisonDimensions?: ComparisonDimensionsType;
// The current digest which is in view
selectedInputDigest?: string;
// Selected metrics to compare
selectedMetrics?: string[];
};

export type ComparisonDimensionsType = Array<{
Expand Down Expand Up @@ -93,6 +95,7 @@ export const useEvaluationComparisonState = (
baselineEvaluationCallId ?? evaluationCallIds[0],
comparisonDimensions: newComparisonDimensions,
selectedInputDigest,
selectedMetrics: newComparisonDimensions.map(dim => dim.metricId),
},
};
}, [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
import {Box} from '@material-ui/core';
import React, {useMemo} from 'react';
import {Popover} from '@mui/material';
import {Switch} from '@wandb/weave/components';
import {Button} from '@wandb/weave/components/Button';
import {
DraggableGrow,
DraggableHandle,
} from '@wandb/weave/components/DraggablePopups';
import {TextField} from '@wandb/weave/components/Form/TextField';
import {Tailwind} from '@wandb/weave/components/Tailwind';
import {maybePluralize} from '@wandb/weave/core/util/string';
import classNames from 'classnames';
import React, {useMemo, useRef, useState} from 'react';

import {buildCompositeMetricsMap} from '../../compositeMetricsUtil';
import {
Expand All @@ -25,7 +36,33 @@ import {PlotlyRadarPlot, RadarPlotData} from './PlotlyRadarPlot';
export const SummaryPlots: React.FC<{
state: EvaluationComparisonState;
}> = props => {
const plotlyRadarData = useNormalizedPlotDataFromMetrics(props.state);
const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics(
props.state
);
const [selectedMetrics, setSelectedMetrics] = useState<string[]>(
Array.from(allMetricNames)
);

// filter down the plotlyRadarData to only include the selected metrics
const filteredPlotlyRadarData = useMemo(() => {
const filteredData: RadarPlotData = {};
for (const [callId, metricBin] of Object.entries(radarData)) {
const metrics: {[metric: string]: number} = {};
for (const [metric, value] of Object.entries(metricBin.metrics)) {
if (selectedMetrics.includes(metric)) {
metrics[metric] = value;
}
}
if (Object.keys(metrics).length > 0) {
filteredData[callId] = {
metrics: metrics,
name: metricBin.name,
color: metricBin.color,
};
}
}
return filteredData;
}, [radarData, selectedMetrics]);

return (
<VerticalBox
Expand All @@ -48,6 +85,16 @@ export const SummaryPlots: React.FC<{
}}>
Summary Metrics
</Box>
<Box sx={{marginLeft: 'auto'}}>
<div style={{display: 'flex', alignItems: 'center'}}>
<div style={{marginRight: '4px'}}>Configure displayed metrics</div>
<MetricsSelector
setSelectedMetrics={setSelectedMetrics}
selectedMetrics={selectedMetrics}
allMetrics={Array.from(allMetricNames)}
/>
</div>
</Box>
</HorizontalBox>
<HorizontalBox
sx={{
Expand All @@ -63,7 +110,10 @@ export const SummaryPlots: React.FC<{
alignContent: 'center',
width: PLOT_HEIGHT,
}}>
<PlotlyRadarPlot height={PLOT_HEIGHT} data={plotlyRadarData} />
<PlotlyRadarPlot
height={PLOT_HEIGHT}
data={filteredPlotlyRadarData}
/>
</Box>
<Box
sx={{
Expand All @@ -75,13 +125,168 @@ export const SummaryPlots: React.FC<{
padding: PLOT_PADDING,
width: PLOT_HEIGHT,
}}>
<PlotlyBarPlot height={PLOT_HEIGHT} data={plotlyRadarData} />
<PlotlyBarPlot height={PLOT_HEIGHT} data={filteredPlotlyRadarData} />
</Box>
</HorizontalBox>
</VerticalBox>
);
};

const MetricsSelector: React.FC<{
setSelectedMetrics: (metrics: string[]) => void;
selectedMetrics: string[];
allMetrics: string[];
}> = ({setSelectedMetrics, selectedMetrics, allMetrics}) => {
const [search, setSearch] = useState('');

const ref = useRef<HTMLDivElement>(null);
const [anchorEl, setAnchorEl] = useState<null | HTMLElement>(null);
const onClick = (event: React.MouseEvent<HTMLElement>) => {
setAnchorEl(anchorEl ? null : ref.current);
setSearch('');
};
const open = Boolean(anchorEl);
const id = open ? 'simple-popper' : undefined;

const filteredCols = search
? allMetrics.filter(col => col.toLowerCase().includes(search.toLowerCase()))
: allMetrics;

const numHidden = allMetrics.length - selectedMetrics.length;
const buttonSuffix = search ? `(${filteredCols.length})` : 'all';

return (
<>
<span ref={ref}>
<Button
variant="ghost"
icon="column"
tooltip="Manage metrics"
onClick={onClick}
/>
</span>
<Popover
id={id}
open={open}
anchorEl={anchorEl}
anchorOrigin={{
vertical: 'bottom',
horizontal: 'center',
}}
transformOrigin={{
vertical: 'top',
horizontal: 'center',
}}
slotProps={{
paper: {
sx: {
overflow: 'visible',
},
},
}}
onClose={() => setAnchorEl(null)}
TransitionComponent={DraggableGrow}>
<Tailwind>
<div className="min-w-[360px] p-12">
<DraggableHandle>
<div className="flex items-center pb-8">
<div className="flex-auto text-xl font-semibold">
Manage metrics
</div>
<div className="ml-16 text-moon-500">
{maybePluralize(numHidden, 'hidden column', 's')}
</div>
</div>
</DraggableHandle>
<div className="mb-8">
<TextField
placeholder="Filter columns"
autoFocus
value={search}
onChange={setSearch}
/>
</div>
<div className="max-h-[300px] overflow-auto">
{Array.from(allMetrics).map((metric: string) => {
const value = metric;
const idSwitch = `toggle-vis_${value}`;
const checked = selectedMetrics.includes(metric);
const label = metric;
const disabled = false;
if (
search &&
!label.toLowerCase().includes(search.toLowerCase())
) {
return null;
}
return (
<div key={value}>
<div
className={classNames(
'flex items-center py-2',
disabled ? 'opacity-40' : ''
)}>
<Switch.Root
id={idSwitch}
size="small"
checked={checked}
onCheckedChange={isOn => {
setSelectedMetrics(
isOn
? [...selectedMetrics, metric]
: selectedMetrics.filter(m => m !== metric)
);
}}
disabled={disabled}>
<Switch.Thumb size="small" checked={checked} />
</Switch.Root>
<label
htmlFor={idSwitch}
className={classNames(
'ml-6',
disabled ? '' : 'cursor-pointer'
)}>
{label}
</label>
</div>
</div>
);
})}
</div>
<div className="mt-8 flex items-center">
<Button
size="small"
variant="quiet"
icon="hide-hidden"
disabled={filteredCols.length === 0}
onClick={() =>
setSelectedMetrics(
selectedMetrics.filter(m => !filteredCols.includes(m))
)
}>
{`Hide ${buttonSuffix}`}
</Button>
<div className="flex-auto" />
<Button
size="small"
variant="quiet"
icon="show-visible"
disabled={filteredCols.length === 0}
onClick={() =>
setSelectedMetrics(
Array.from(new Set([...selectedMetrics, ...filteredCols]))
)
}>
{`Show ${buttonSuffix}`}
</Button>
</div>
</div>
</Tailwind>
</Popover>
</>
);
};

const normalizeValues = (values: Array<number | undefined>): number[] => {
// find the max value
// find the power of 2 that is greater than the max value
Expand All @@ -93,7 +298,7 @@ const normalizeValues = (values: Array<number | undefined>): number[] => {

const useNormalizedPlotDataFromMetrics = (
state: EvaluationComparisonState
): RadarPlotData => {
): {radarData: RadarPlotData; allMetricNames: Set<string>} => {
const compositeMetrics = useMemo(() => {
return buildCompositeMetricsMap(state.data, 'summary');
}, [state]);
Expand Down Expand Up @@ -137,7 +342,7 @@ const useNormalizedPlotDataFromMetrics = (
evalScores,
};
});
return Object.fromEntries(
const radarData = Object.fromEntries(
callIds.map(callId => {
const evalCall = state.data.evaluationCalls[callId];
return [
Expand All @@ -157,5 +362,7 @@ const useNormalizedPlotDataFromMetrics = (
];
})
);
const allMetricNames = new Set(normalizedMetrics.map(m => m.metricLabel));
return {radarData, allMetricNames};
}, [callIds, compositeMetrics, state.data.evaluationCalls]);
};

0 comments on commit b8d7417

Please sign in to comment.