Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add column selection in eval comparison page #2632

Merged
merged 9 commits into from
Oct 10, 2024
16 changes: 16 additions & 0 deletions weave-js/src/components/PagePanelComponents/Home/Browse3.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,7 @@ const CompareEvaluationsBinding = () => {
const evaluationCallIds = useMemo(() => {
return JSON.parse(query.evaluationCallIds);
}, [query.evaluationCallIds]);

const onEvaluationCallIdsUpdate = useCallback(
(newEvaluationCallIds: string[]) => {
const newQuery = new URLSearchParams(location.search);
Expand All @@ -936,12 +937,27 @@ const CompareEvaluationsBinding = () => {
},
[history, location.search]
);

const selectedMetrics: Record<string, boolean> | null = useMemo(() => {
try {
return JSON.parse(query.metrics);
} catch (e) {
return null;
}
}, [query.metrics]);
const setSelectedMetrics = (newModel: Record<string, boolean>) => {
const newQuery = new URLSearchParams(location.search);
newQuery.set('metrics', JSON.stringify(newModel));
history.push({search: newQuery.toString()});
};
return (
<CompareEvaluationsPage
entity={entity}
project={project}
evaluationCallIds={evaluationCallIds}
onEvaluationCallIdsUpdate={onEvaluationCallIdsUpdate}
selectedMetrics={selectedMetrics}
setSelectedMetrics={setSelectedMetrics}
/>
);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ const useCallTabs = (call: CallSchema) => {
entity={call.entity}
project={call.project}
evaluationCallIds={[call.callId]}
// Dont persist metric selection in the URL
selectedMetrics={{}}
setSelectedMetrics={() => {}}
// Dont persist changes to evaluationCallIds in the URL
onEvaluationCallIdsUpdate={() => {}}
/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ type CompareEvaluationsPageProps = {
project: string;
evaluationCallIds: string[];
onEvaluationCallIdsUpdate: (newEvaluationCallIds: string[]) => void;
selectedMetrics: Record<string, boolean> | null;
setSelectedMetrics: (newModel: Record<string, boolean>) => void;
};

export const CompareEvaluationsPage: React.FC<
Expand All @@ -57,6 +59,8 @@ export const CompareEvaluationsPage: React.FC<
project={props.project}
evaluationCallIds={props.evaluationCallIds}
onEvaluationCallIdsUpdate={props.onEvaluationCallIdsUpdate}
selectedMetrics={props.selectedMetrics}
setSelectedMetrics={props.setSelectedMetrics}
/>
),
},
Expand Down Expand Up @@ -112,6 +116,8 @@ export const CompareEvaluationsPageContent: React.FC<
<CompareEvaluationsProvider
entity={props.entity}
project={props.project}
selectedMetrics={props.selectedMetrics}
setSelectedMetrics={props.setSelectedMetrics}
initialEvaluationCallIds={props.evaluationCallIds}
baselineEvaluationCallId={baselineEvaluationCallId ?? undefined}
comparisonDimensions={comparisonDimensions ?? undefined}
Expand Down Expand Up @@ -179,7 +185,7 @@ const ReturnToEvaluationsButton: FC<{entity: string; project: string}> = ({
const CompareEvaluationsPageInner: React.FC<{
height: number;
}> = props => {
const {state} = useCompareEvaluationsState();
const {state, setSelectedMetrics} = useCompareEvaluationsState();
const showExampleFilter =
Object.keys(state.data.evaluationCalls).length === 2;
const showExamples = Object.keys(state.data.resultRows).length > 0;
Expand All @@ -200,7 +206,7 @@ const CompareEvaluationsPageInner: React.FC<{
evaluationCalls={Object.values(state.data.evaluationCalls)}
/>
<ComparisonDefinitionSection state={state} />
<SummaryPlots state={state} />
<SummaryPlots state={state} setSelectedMetrics={setSelectedMetrics} />
<ScorecardSection state={state} />
{showExamples ? (
<>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ const CompareEvaluationsContext = React.createContext<{
React.SetStateAction<ComparisonDimensionsType | null>
>;
setSelectedInputDigest: React.Dispatch<React.SetStateAction<string | null>>;
setSelectedMetrics: (newModel: Record<string, boolean>) => void;
addEvaluationCall: (newCallId: string) => void;
removeEvaluationCall: (callId: string) => void;
} | null>(null);
Expand All @@ -31,6 +32,9 @@ export const useCompareEvaluationsState = () => {
export const CompareEvaluationsProvider: React.FC<{
entity: string;
project: string;
selectedMetrics: Record<string, boolean> | null;
setSelectedMetrics: (newModel: Record<string, boolean>) => void;

initialEvaluationCallIds: string[];
onEvaluationCallIdsUpdate: (newEvaluationCallIds: string[]) => void;
setBaselineEvaluationCallId: React.Dispatch<
Expand All @@ -46,13 +50,15 @@ export const CompareEvaluationsProvider: React.FC<{
}> = ({
entity,
project,
selectedMetrics,
setSelectedMetrics,

initialEvaluationCallIds,
onEvaluationCallIdsUpdate,
setBaselineEvaluationCallId,
setComparisonDimensions,

setSelectedInputDigest,

baselineEvaluationCallId,
comparisonDimensions,
selectedInputDigest,
Expand All @@ -67,7 +73,8 @@ export const CompareEvaluationsProvider: React.FC<{
evaluationCallIds,
baselineEvaluationCallId,
comparisonDimensions,
selectedInputDigest
selectedInputDigest,
selectedMetrics ?? undefined
);

const value = useMemo(() => {
Expand All @@ -79,6 +86,7 @@ export const CompareEvaluationsProvider: React.FC<{
setBaselineEvaluationCallId,
setComparisonDimensions,
setSelectedInputDigest,
setSelectedMetrics,
addEvaluationCall: (newCallId: string) => {
const newEvaluationCallIds = [...evaluationCallIds, newCallId];
setEvaluationCallIds(newEvaluationCallIds);
Expand All @@ -101,6 +109,7 @@ export const CompareEvaluationsProvider: React.FC<{
setBaselineEvaluationCallId,
setComparisonDimensions,
setSelectedInputDigest,
setSelectedMetrics,
]);

if (!value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ export type CompositeSummaryMetricGroupForKeyPath = {
*/
export const buildCompositeMetricsMap = (
data: EvaluationComparisonData,
mType: MetricType
mType: MetricType,
selectedMetrics: Record<string, boolean> | undefined = undefined
): CompositeScoreMetrics => {
const composite: CompositeScoreMetrics = {};

Expand All @@ -93,6 +94,12 @@ export const buildCompositeMetricsMap = (
Object.entries(metricDefinitionMap).forEach(([metricId, metric]) => {
const groupName = groupNameForMetric(metric);
const ref = refForMetric(metric);
const keyPath = flattenedDimensionPath(metric);

if (selectedMetrics && !selectedMetrics[keyPath]) {
// Skip metrics that are not in the selectedMetrics map
return;
}

if (!composite[groupName]) {
composite[groupName] = {
Expand All @@ -105,8 +112,6 @@ export const buildCompositeMetricsMap = (
metricGroup.scorerRefs.push(ref);
}

const keyPath = flattenedDimensionPath(metric);

if (!metricGroup.metrics[keyPath]) {
metricGroup.metrics[keyPath] = {
scorerAgnosticMetricDef: _.omit(metric, 'scorerOpOrObjRef'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ export type EvaluationComparisonState = {
comparisonDimensions?: ComparisonDimensionsType;
// The current digest which is in view
selectedInputDigest?: string;
// The selected metrics to display
selectedMetrics?: Record<string, boolean>;
};

export type ComparisonDimensionsType = Array<{
Expand All @@ -43,7 +45,8 @@ export const useEvaluationComparisonState = (
evaluationCallIds: string[],
baselineEvaluationCallId?: string,
comparisonDimensions?: ComparisonDimensionsType,
selectedInputDigest?: string
selectedInputDigest?: string,
selectedMetrics?: Record<string, boolean>
): Loadable<EvaluationComparisonState> => {
const data = useEvaluationComparisonData(entity, project, evaluationCallIds);

Expand Down Expand Up @@ -93,6 +96,7 @@ export const useEvaluationComparisonState = (
baselineEvaluationCallId ?? evaluationCallIds[0],
comparisonDimensions: newComparisonDimensions,
selectedInputDigest,
selectedMetrics,
},
};
}, [
Expand All @@ -102,6 +106,7 @@ export const useEvaluationComparisonState = (
evaluationCallIds,
comparisonDimensions,
selectedInputDigest,
selectedMetrics,
]);

return value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,13 @@ export const ExampleCompareSection: React.FC<{
const {ref1, ref2} = useLinkHorizontalScroll();

const compositeScoreMetrics = useMemo(
() => buildCompositeMetricsMap(props.state.data, 'score'),
[props.state.data]
() =>
buildCompositeMetricsMap(
props.state.data,
'score',
props.state.selectedMetrics
),
[props.state.data, props.state.selectedMetrics]
);

if (target == null) {
Expand All @@ -252,9 +257,13 @@ export const ExampleCompareSection: React.FC<{
.length;
});
const numEvals = numTrials.length;
// Get derived scores, then filter out any not in the selected metrics
const derivedScores = Object.values(
getMetricIds(props.state.data, 'score', 'derived')
).filter(
score => props.state.selectedMetrics?.[flattenedDimensionPath(score)]
);

const numMetricScorers = metricGroupNames.length;
const numDerivedScores = derivedScores.length;
const numMetricsPerScorer = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ const rowIsSelected = (
export const useFilteredAggregateRows = (state: EvaluationComparisonState) => {
const leafDims = useMemo(() => getOrderedCallIds(state), [state]);
const compositeMetricsMap = useMemo(
() => buildCompositeMetricsMap(state.data, 'score'),
[state.data]
() => buildCompositeMetricsMap(state.data, 'score', state.selectedMetrics),
[state.data, state.selectedMetrics]
);

const flattenedRows = useMemo(() => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ export const ScorecardSection: React.FC<{
const [diffOnly, setDiffOnly] = React.useState(true);

const compositeSummaryMetrics = useMemo(() => {
return buildCompositeMetricsMap(props.state.data, 'summary');
return buildCompositeMetricsMap(
props.state.data,
'summary',
props.state.selectedMetrics
);
}, [props.state]);

const onCallClick = usePeekCall(
Expand Down
Loading
Loading