Skip to content

Commit

Permalink
fix(ui): Evaluations -- normalize radar data, show real values for ba…
Browse files Browse the repository at this point in the history
…r charts (#3199)
  • Loading branch information
gtarpenning authored Dec 11, 2024
1 parent 4e1e2c9 commit bf39669
Showing 1 changed file with 50 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export const SummaryPlots: React.FC<{
state: EvaluationComparisonState;
setSelectedMetrics: (newModel: Record<string, boolean>) => void;
}> = ({state, setSelectedMetrics}) => {
const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics(state);
const {radarData, allMetricNames} = usePlotDataFromMetrics(state);
const {selectedMetrics} = state;

// Initialize selectedMetrics if null
Expand Down Expand Up @@ -237,11 +237,10 @@ const useFilteredData = (
return data;
}, [radarData, selectedMetrics]);

function getMetricValuesFromRadarData(radarData: RadarPlotData): {
function getMetricValuesMap(radarData: RadarPlotData): {
[metric: string]: number[];
} {
const metricValues: {[metric: string]: number[]} = {};
// Gather all values for each metric
Object.values(radarData).forEach(callData => {
Object.entries(callData.metrics).forEach(([metric, value]) => {
if (!metricValues[metric]) {
Expand All @@ -253,37 +252,54 @@ function getMetricValuesFromRadarData(radarData: RadarPlotData): {
return metricValues;
}

function getMetricMinsFromRadarData(radarData: RadarPlotData): {
[metric: string]: number;
function normalizeMetricValues(values: number[]): {
normalizedValues: number[];
normalizer: number;
} {
const metricValues = getMetricValuesFromRadarData(radarData);
const metricMins: {[metric: string]: number} = {};
Object.entries(metricValues).forEach(([metric, values]) => {
metricMins[metric] = Math.min(...values);
});
return metricMins;
const min = Math.min(...values);
const max = Math.max(...values);

if (min === max) {
return {
normalizedValues: values.map(() => 0.5),
normalizer: 1,
};
}

// Handle negative values by shifting
const shiftedValues = min < 0 ? values.map(v => v - min) : values;
const maxValue = min < 0 ? max - min : max;

const maxPower = Math.ceil(Math.log2(maxValue));
const normalizer = Math.pow(2, maxPower);

return {
normalizedValues: shiftedValues.map(v => v / normalizer),
normalizer,
};
}

function normalizeDataForRadarPlot(radarData: RadarPlotData): RadarPlotData {
const metricMins = getMetricMinsFromRadarData(radarData);
function normalizeDataForRadarPlot(
radarDataOriginal: RadarPlotData
): RadarPlotData {
const radarData = Object.fromEntries(
Object.entries(radarDataOriginal).map(([callId, callData]) => [
callId,
{...callData, metrics: {...callData.metrics}},
])
);

const normalizedData: RadarPlotData = {};
Object.entries(radarData).forEach(([callId, callData]) => {
normalizedData[callId] = {
name: callData.name,
color: callData.color,
metrics: {},
};
const metricValues = getMetricValuesMap(radarData);

Object.entries(callData.metrics).forEach(([metric, value]) => {
const min = metricMins[metric];
// Only shift values if there are negative values
const normalizedValue = min < 0 ? value - min : value;
normalizedData[callId].metrics[metric] = normalizedValue;
// Normalize each metric independently
Object.entries(metricValues).forEach(([metric, values]) => {
const {normalizedValues} = normalizeMetricValues(values);
Object.values(radarData).forEach((callData, index) => {
callData.metrics[metric] = normalizedValues[index];
});
});

return normalizedData;
return radarData;
}

const useBarPlotData = (filteredData: RadarPlotData) =>
Expand Down Expand Up @@ -317,7 +333,9 @@ const useBarPlotData = (filteredData: RadarPlotData) =>
type: 'bar',
y: metricBin.values,
x: metricBin.callIds,
text: metricBin.values.map(value => value.toFixed(3)),
text: metricBin.values.map(value =>
Number.isInteger(value) ? value.toString() : value.toFixed(3)
),
textposition: 'outside',
textfont: {size: 14, color: 'black'},
name: metric,
Expand Down Expand Up @@ -408,16 +426,7 @@ const usePaginatedPlots = (
return {plotsToShow, totalPlots, startIndex, endIndex, totalPages};
};

function normalizeValues(values: Array<number | undefined>): number[] {
// find the max value
// find the power of 2 that is greater than the max value
// divide all values by that power of 2
const maxVal = Math.max(...(values.filter(v => v !== undefined) as number[]));
const maxPower = Math.ceil(Math.log2(maxVal));
return values.map(val => (val ? val / 2 ** maxPower : 0));
}

const useNormalizedPlotDataFromMetrics = (
const usePlotDataFromMetrics = (
state: EvaluationComparisonState
): {radarData: RadarPlotData; allMetricNames: Set<string>} => {
const compositeMetrics = useMemo(() => {
Expand All @@ -428,7 +437,7 @@ const useNormalizedPlotDataFromMetrics = (
}, [state]);

return useMemo(() => {
const normalizedMetrics = Object.values(compositeMetrics)
const metrics = Object.values(compositeMetrics)
.map(scoreGroup => Object.values(scoreGroup.metrics))
.flat()
.map(metric => {
Expand All @@ -449,11 +458,8 @@ const useNormalizedPlotDataFromMetrics = (
return val;
}
});
const normalizedValues = normalizeValues(values);
const evalScores: {[evalCallId: string]: number | undefined} =
Object.fromEntries(
callIds.map((key, i) => [key, normalizedValues[i]])
);
Object.fromEntries(callIds.map((key, i) => [key, values[i]]));

const metricLabel = flattenedDimensionPath(
Object.values(metric.scorerRefs)[0].metric
Expand All @@ -472,7 +478,7 @@ const useNormalizedPlotDataFromMetrics = (
name: evalCall.name,
color: evalCall.color,
metrics: Object.fromEntries(
normalizedMetrics.map(metric => {
metrics.map(metric => {
return [
metric.metricLabel,
metric.evalScores[evalCall.callId] ?? 0,
Expand All @@ -483,7 +489,7 @@ const useNormalizedPlotDataFromMetrics = (
];
})
);
const allMetricNames = new Set(normalizedMetrics.map(m => m.metricLabel));
const allMetricNames = new Set(metrics.map(m => m.metricLabel));
return {radarData, allMetricNames};
}, [callIds, compositeMetrics, state.data.evaluationCalls]);
};

0 comments on commit bf39669

Please sign in to comment.