diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx index c23ffcce04d..a66fe24f57f 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx @@ -29,7 +29,7 @@ export const SummaryPlots: React.FC<{ state: EvaluationComparisonState; setSelectedMetrics: (newModel: Record) => void; }> = ({state, setSelectedMetrics}) => { - const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics(state); + const {radarData, allMetricNames} = usePlotDataFromMetrics(state); const {selectedMetrics} = state; // Initialize selectedMetrics if null @@ -237,11 +237,10 @@ const useFilteredData = ( return data; }, [radarData, selectedMetrics]); -function getMetricValuesFromRadarData(radarData: RadarPlotData): { +function getMetricValuesMap(radarData: RadarPlotData): { [metric: string]: number[]; } { const metricValues: {[metric: string]: number[]} = {}; - // Gather all values for each metric Object.values(radarData).forEach(callData => { Object.entries(callData.metrics).forEach(([metric, value]) => { if (!metricValues[metric]) { @@ -253,37 +252,54 @@ function getMetricValuesFromRadarData(radarData: RadarPlotData): { return metricValues; } -function getMetricMinsFromRadarData(radarData: RadarPlotData): { - [metric: string]: number; +function normalizeMetricValues(values: number[]): { + normalizedValues: number[]; + normalizer: number; } { - const metricValues = getMetricValuesFromRadarData(radarData); - const metricMins: {[metric: string]: number} = {}; - Object.entries(metricValues).forEach(([metric, values]) => { - metricMins[metric] = Math.min(...values); - }); - return metricMins; + const min = Math.min(...values); + const max = Math.max(...values); + + if (min === max) { + return { + normalizedValues: values.map(() => 0.5), + normalizer: 1, + }; + } + + // Handle negative values by shifting + const shiftedValues = min < 0 ? values.map(v => v - min) : values; + const maxValue = min < 0 ? max - min : max; + + const maxPower = Math.ceil(Math.log2(maxValue)); + const normalizer = Math.pow(2, maxPower); + + return { + normalizedValues: shiftedValues.map(v => v / normalizer), + normalizer, + }; } -function normalizeDataForRadarPlot(radarData: RadarPlotData): RadarPlotData { - const metricMins = getMetricMinsFromRadarData(radarData); +function normalizeDataForRadarPlot( + radarDataOriginal: RadarPlotData +): RadarPlotData { + const radarData = Object.fromEntries( + Object.entries(radarDataOriginal).map(([callId, callData]) => [ + callId, + {...callData, metrics: {...callData.metrics}}, + ]) + ); - const normalizedData: RadarPlotData = {}; - Object.entries(radarData).forEach(([callId, callData]) => { - normalizedData[callId] = { - name: callData.name, - color: callData.color, - metrics: {}, - }; + const metricValues = getMetricValuesMap(radarData); - Object.entries(callData.metrics).forEach(([metric, value]) => { - const min = metricMins[metric]; - // Only shift values if there are negative values - const normalizedValue = min < 0 ? value - min : value; - normalizedData[callId].metrics[metric] = normalizedValue; + // Normalize each metric independently + Object.entries(metricValues).forEach(([metric, values]) => { + const {normalizedValues} = normalizeMetricValues(values); + Object.values(radarData).forEach((callData, index) => { + callData.metrics[metric] = normalizedValues[index]; }); }); - return normalizedData; + return radarData; } const useBarPlotData = (filteredData: RadarPlotData) => @@ -408,16 +424,7 @@ const usePaginatedPlots = ( return {plotsToShow, totalPlots, startIndex, endIndex, totalPages}; }; -function normalizeValues(values: Array): number[] { - // find the max value - // find the power of 2 that is greater than the max value - // divide all values by that power of 2 - const maxVal = Math.max(...(values.filter(v => v !== undefined) as number[])); - const maxPower = Math.ceil(Math.log2(maxVal)); - return values.map(val => (val ? val / 2 ** maxPower : 0)); -} - -const useNormalizedPlotDataFromMetrics = ( +const usePlotDataFromMetrics = ( state: EvaluationComparisonState ): {radarData: RadarPlotData; allMetricNames: Set} => { const compositeMetrics = useMemo(() => { @@ -428,7 +435,7 @@ const useNormalizedPlotDataFromMetrics = ( }, [state]); return useMemo(() => { - const normalizedMetrics = Object.values(compositeMetrics) + const metrics = Object.values(compositeMetrics) .map(scoreGroup => Object.values(scoreGroup.metrics)) .flat() .map(metric => { @@ -449,11 +456,8 @@ const useNormalizedPlotDataFromMetrics = ( return val; } }); - const normalizedValues = normalizeValues(values); const evalScores: {[evalCallId: string]: number | undefined} = - Object.fromEntries( - callIds.map((key, i) => [key, normalizedValues[i]]) - ); + Object.fromEntries(callIds.map((key, i) => [key, values[i]])); const metricLabel = flattenedDimensionPath( Object.values(metric.scorerRefs)[0].metric @@ -472,7 +476,7 @@ const useNormalizedPlotDataFromMetrics = ( name: evalCall.name, color: evalCall.color, metrics: Object.fromEntries( - normalizedMetrics.map(metric => { + metrics.map(metric => { return [ metric.metricLabel, metric.evalScores[evalCall.callId] ?? 0, @@ -483,7 +487,7 @@ const useNormalizedPlotDataFromMetrics = ( ]; }) ); - const allMetricNames = new Set(normalizedMetrics.map(m => m.metricLabel)); + const allMetricNames = new Set(metrics.map(m => m.metricLabel)); return {radarData, allMetricNames}; }, [callIds, compositeMetrics, state.data.evaluationCalls]); };