Skip to content

Commit

Permalink
Human: Ok, getting to a checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Oct 8, 2024
1 parent c2714d5 commit 91f56ea
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 72 deletions.
Original file line number Diff line number Diff line change
@@ -1,49 +1,62 @@
import {Box} from '@mui/material';
import {
GridColDef,
GridColumnGroup,
GridColumnGroupingModel,
GridColumnNode,
GridPaginationModel,
GridRenderCellParams,
GridSortDirection,
} from '@mui/x-data-grid-pro';
import React, {useCallback, useMemo, useState} from 'react';

import {parseRefMaybe, SmallRef} from '../../../Browse2/SmallRef';
import {StyledDataGrid} from '../../StyledDataGrid';
import {PaginationButtons} from '../CallsPage/CallsTableButtons';
import {buildTree} from '../common/tabularListViews/buildTree';
import {LeaderboardData} from './hooks';

interface LeaderboardGridProps {
data: LeaderboardData;
loading: boolean;
onCellClick: (modelName: string, metricName: string, score: number) => void;
}

interface RowData {
type RowData = {
id: number;
model: string;
[key: string]: number | string;
}
} & LeaderboardData['scores'][string];

export const LeaderboardGrid: React.FC<LeaderboardGridProps> = ({
data,
loading,
onCellClick,
}) => {
const [paginationModel, setPaginationModel] = useState<GridPaginationModel>({
pageSize: 25,
page: 0,
});

const orderedMetrics = useMemo(() => {
return Object.keys(data.metrics);
// return Object.keys(data.metrics).sort((a, b) => {
// return a.localeCompare(b);
// });
}, [data.metrics]);

const metricRanges = useMemo(() => {
const ranges: {[key: string]: {min: number; max: number}} = {};
Object.keys(data.metrics).forEach(metric => {
orderedMetrics.forEach(metric => {
const scores = data.models
.map(model => data.scores?.[model]?.[metric])
.map(model => data.scores?.[model]?.[metric]?.value)
.filter(score => score !== undefined);
ranges[metric] = {
min: Math.min(...scores),
max: Math.max(...scores),
};
});
return ranges;
}, [data]);
}, [data.models, data.scores, orderedMetrics]);

const getColorForScore = useCallback(
(metric: string, score: number | undefined) => {
Expand All @@ -57,14 +70,74 @@ export const LeaderboardGrid: React.FC<LeaderboardGridProps> = ({
[metricRanges]
);

const columns: GridColDef[] = useMemo(
const rows: RowData[] = useMemo(
() =>
data.models.map((model, index) => ({
id: index,
model,
...data.scores[model],
})) as RowData[],
[data.models, data.scores]
);

const columns: Array<GridColDef<RowData>> = useMemo(
() => [
{field: 'model', headerName: 'Model', width: 200, flex: 1},
...Object.keys(data.metrics).map(metric => ({
{
field: 'model',
headerName: 'Model',
minWidth: 200,
flex: 1,
renderCell: (params: GridRenderCellParams) => {
const modelRef = parseRefMaybe(params.value);
if (modelRef) {
return (
<div
style={{
width: '100%',
height: '100%',
alignContent: 'center',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
lineHeight: '20px',
}}>
<SmallRef objRef={modelRef} />
</div>
);
}
return <div>{params.value}</div>;
},
},
...orderedMetrics.map(metric => ({
field: metric,
headerName: metric,
width: 130,
headerName: metric.split('.').pop(),
minWidth: 130,
flex: 1,
valueGetter: (value: RowData) => {
return value?.value;
},
getSortComparator: (dir: GridSortDirection) => (a: any, b: any) => {
const aValue = a;
const bValue = b;
if (aValue == null && bValue == null) {
return 0;
}
// Ignoring direction here allows nulls to always sort to the end
if (aValue == null) {
return 1;
}
if (bValue == null) {
return -1;
}
if (typeof aValue === 'number' && typeof bValue === 'number') {
if (dir === 'asc') {
return aValue - bValue;
} else {
return bValue - aValue;
}
}
return aValue.localeCompare(bValue);
},
renderCell: (params: GridRenderCellParams) => {
let inner = params.value;
if (typeof inner === 'number') {
Expand Down Expand Up @@ -100,32 +173,43 @@ export const LeaderboardGrid: React.FC<LeaderboardGridProps> = ({
},
})),
],
[data.metrics, getColorForScore, onCellClick]
[getColorForScore, onCellClick, orderedMetrics]
);

const rows: RowData[] = useMemo(
() =>
data.models.map((model, index) => ({
id: index,
model,
...data.scores[model],
})),
[data.models, data.scores]
);
const tree = buildTree([...Object.keys(data.metrics)]);
let groupingModel: GridColumnGroupingModel = tree.children.filter(
c => 'groupId' in c
) as GridColumnGroup[];
groupingModel = walkGroupingModel(groupingModel, node => {
if ('groupId' in node) {
if (node.children.length === 1) {
if ('groupId' in node.children[0]) {
const currNode = node;
node = node.children[0];
node.headerName = currNode.headerName + '.' + node.headerName;
} else {
// pass
// node = node.children[0];
}
}
}
return node;
}) as GridColumnGroup[];
console.log(groupingModel);

const groupingModel: GridColumnGroupingModel = useMemo(
() => {
return [
{
groupId: 'metrics',
children: Object.keys(data.metrics).map(metric => ({
field: metric,
})),
},
];
},
[data.metrics]
);
// const groupingModel: GridColumnGroupingModel = useMemo(
// () => {
// return [
// {
// groupId: 'metrics',
// children: Object.keys(data.metrics).map(metric => ({
// field: metric,
// })),
// },
// ];
// },
// [data.metrics]
// );

return (
<Box
Expand All @@ -146,6 +230,9 @@ export const LeaderboardGrid: React.FC<LeaderboardGridProps> = ({
pageSizeOptions={[25]}
disableRowSelectionOnClick
hideFooterSelectedRowCount
disableMultipleColumnsSorting={false}
columnHeaderHeight={40}
loading={loading}
sx={{
borderRadius: 0,
'& .MuiDataGrid-footerContainer': {
Expand Down Expand Up @@ -174,3 +261,15 @@ export const LeaderboardGrid: React.FC<LeaderboardGridProps> = ({
</Box>
);
};

const walkGroupingModel = (
nodes: GridColumnNode[],
fn: (node: GridColumnNode) => GridColumnNode
) => {
return nodes.map(node => {
if ('children' in node) {
node.children = walkGroupingModel(node.children, fn);
}
return fn(node);
});
};
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import {Box} from '@mui/material';
import React, {useState} from 'react';
import {useHistory} from 'react-router-dom';

import {useWeaveflowRouteContext} from '../../context';
import {EditableMarkdown} from './EditableMarkdown';
import {useLeaderboardData} from './hooks';
import {LeaderboardGrid} from './LeaderboardGrid';

const USE_COMPARE_EVALUATIONS_PAGE = false;

type LeaderboardPageProps = {
entity: string;
project: string;
Expand All @@ -17,49 +21,41 @@ export const LeaderboardPage: React.FC<LeaderboardPageProps> = props => {
};

const DEFAULT_DESCRIPTION = `# Leaderboard`;
const EXAMPLE_DESCRIPTION = `# Welcome to the Leaderboard!
This leaderboard showcases the performance of various models across different metrics. Here's a quick guide:
## Metrics Explained
- **Accuracy**: Overall correctness of the model
- **F1 Score**: Balanced measure of precision and recall
- **Precision**: Ratio of true positives to all positive predictions
- **Recall**: Ratio of true positives to all actual positives
- **AUC-ROC**: Area Under the Receiver Operating Characteristic curve
## How to Interpret the Results
1. Higher scores are generally better for all metrics.
2. Look for models that perform well across *multiple* metrics.
3. Consider the trade-offs between different metrics based on your specific use case.
> Note: Click on any cell in the grid to get more detailed information about that specific score.
Happy analyzing!`;

export const LeaderboardPageContent: React.FC<LeaderboardPageProps> = props => {
const {entity, project} = props;
const [description, setDescription] = useState(EXAMPLE_DESCRIPTION);
const data = useLeaderboardData(entity, project);
const [description, setDescription] = useState("");
const {loading, data} = useLeaderboardData(entity, project);

// const setDescription = useCallback((newDescription: string) => {
// setDescriptionRaw(newDescription.trim());
// }, []);

const {peekingRouter} = useWeaveflowRouteContext();
const history = useHistory();

const handleCellClick = (
modelName: string,
metricName: string,
score: number
) => {
console.log(`Clicked on ${modelName} for ${metricName}: ${score}%`);
// TODO: Implement action on cell click
const sourceCallId = data.scores?.[modelName]?.[metricName]?.sourceCallId;
if (sourceCallId) {
let to: string;
if (USE_COMPARE_EVALUATIONS_PAGE) {
to = peekingRouter.compareEvaluationsUri(entity, project, [
sourceCallId,
]);
} else {
to = peekingRouter.callUIUrl(entity, project, '', sourceCallId, null);
}
history.push(to);
}
};

return (
<Box display="flex" flexDirection="column" height="100%">
<Box flexGrow={1} flexShrink={0} maxHeight="50%" overflow="auto">
<Box flexShrink={0} maxHeight="50%" overflow="auto">
<EditableMarkdown
value={description}
onChange={setDescription}
Expand All @@ -72,7 +68,11 @@ export const LeaderboardPageContent: React.FC<LeaderboardPageProps> = props => {
flexDirection="column"
overflow="hidden"
minHeight="50%">
<LeaderboardGrid data={data} onCellClick={handleCellClick} />
<LeaderboardGrid
loading={loading}
data={data}
onCellClick={handleCellClick}
/>
</Box>
</Box>
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@ export type LeaderboardData = {
};
};
models: string[];
scores: {[modelId: string]: {[metricId: string]: number}};
scores: {
[modelId: string]: {
[metricId: string]: {value: number; sourceCallId: string};
};
};
};

const useEvaluations = (
export const useLeaderboardData = (
entity: string,
project: string
): {loading: boolean; data: LeaderboardData} => {
Expand Down Expand Up @@ -129,7 +133,10 @@ const useEvaluations = (
if (!finalData.scores[modelName]) {
finalData.scores[modelName] = {};
}
finalData.scores[modelName][metricName] = score;
finalData.scores[modelName][metricName] = {
value: score,
sourceCallId: r.callId,
};
});
});

Expand All @@ -145,11 +152,3 @@ const useEvaluations = (

return results;
};

export const useLeaderboardData = (
entity: string,
project: string
): LeaderboardData => {
const {loading, data} = useEvaluations(entity, project);
return data;
};
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ export const CallLink: React.FC<{
}> = props => {
const history = useHistory();
const {peekingRouter} = useWeaveflowRouteContext();

const opName = opNiceName(props.opName);

// Custom logic to calculate path and tracetree here is not good. Shows
Expand Down

0 comments on commit 91f56ea

Please sign in to comment.