diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx index b5c1a4bf96c..66ce56c4cb0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx @@ -105,7 +105,7 @@ const AddEvaluationButton: React.FC<{ ); const expandedRefCols = useMemo(() => new Set(), []); // Don't query for output here, re-queried in tsDataModelHooksEvaluationComparison.ts - const columns = useMemo(() => ['inputs'], []); + const columns = useMemo(() => ['inputs', 'display_name'], []); const calls = useCallsForQuery( props.state.data.entity, props.state.data.project, @@ -137,7 +137,7 @@ const AddEvaluationButton: React.FC<{ return; } - const filteredOptions = calls.result.filter(call => { + const filteredOptions = evalsNotComparing.filter(call => { if ( (call.displayName ?? call.spanName) .toLowerCase() diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/MetricsSelector.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/MetricsSelector.tsx new file mode 100644 index 00000000000..3e6dfdd30a1 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/MetricsSelector.tsx @@ -0,0 +1,173 @@ +import {Popover} from '@mui/material'; +import {Switch} from '@wandb/weave/components'; +import {Button} from '@wandb/weave/components/Button'; +import { + DraggableGrow, + DraggableHandle, +} from '@wandb/weave/components/DraggablePopups'; +import {TextField} from '@wandb/weave/components/Form/TextField'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import {maybePluralize} from '@wandb/weave/core/util/string'; +import classNames from 'classnames'; +import React, {useRef, useState} from 'react'; + +export const MetricsSelector: React.FC<{ + setSelectedMetrics: (newModel: Record) => void; + selectedMetrics: Record | undefined; + allMetrics: string[]; +}> = ({setSelectedMetrics, selectedMetrics, allMetrics}) => { + const [search, setSearch] = useState(''); + + const ref = useRef(null); + const [anchorEl, setAnchorEl] = useState(null); + const onClick = (event: React.MouseEvent) => { + setAnchorEl(anchorEl ? null : ref.current); + setSearch(''); + }; + const open = Boolean(anchorEl); + const id = open ? 'simple-popper' : undefined; + + const filteredCols = search + ? allMetrics.filter(col => col.toLowerCase().includes(search.toLowerCase())) + : allMetrics; + + const shownMetrics = Object.values(selectedMetrics ?? {}).filter(Boolean); + + const numHidden = allMetrics.length - shownMetrics.length; + const buttonSuffix = search ? `(${filteredCols.length})` : 'all'; + + return ( + <> + + +
+ +
+ + + + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx index 9706ac09567..7942aea195e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx @@ -2,50 +2,54 @@ import * as Plotly from 'plotly.js'; import React, {useEffect, useMemo, useRef} from 'react'; import {PLOT_GRID_COLOR} from '../../ecpConstants'; -import {RadarPlotData} from './PlotlyRadarPlot'; export const PlotlyBarPlot: React.FC<{ height: number; - data: RadarPlotData; + yRange: [number, number]; + plotlyData: Plotly.Data; }> = props => { const divRef = useRef(null); - const plotlyData: Plotly.Data[] = useMemo(() => { - return Object.keys(props.data).map((key, i) => { - const {metrics, name, color} = props.data[key]; - return { - type: 'bar', - y: Object.values(metrics), - x: Object.keys(metrics), - name, - marker: {color}, - }; - }); - }, [props.data]); - const plotlyLayout: Partial = useMemo(() => { return { - height: props.height - 40, + height: props.height - 30, showlegend: false, margin: { - l: 0, + l: 20, r: 0, b: 20, - t: 0, - pad: 0, + t: 26, }, + bargap: 0.1, xaxis: { automargin: true, fixedrange: true, gridcolor: PLOT_GRID_COLOR, linecolor: PLOT_GRID_COLOR, + showticklabels: false, }, yaxis: { fixedrange: true, + range: props.yRange, gridcolor: PLOT_GRID_COLOR, linecolor: PLOT_GRID_COLOR, + showticklabels: true, + tickfont: { + size: 10, + }, + }, + title: { + multiline: true, + text: props.plotlyData.name ?? '', + font: {size: 12}, + xref: 'paper', + x: 0.5, + y: 1, + yanchor: 'top', + pad: {t: 2}, }, }; - }, [props.height]); + }, [props.height, props.plotlyData, props.yRange]); + const plotlyConfig = useMemo(() => { return { displayModeBar: false, @@ -57,11 +61,11 @@ export const PlotlyBarPlot: React.FC<{ useEffect(() => { Plotly.newPlot( divRef.current as any, - plotlyData, + [props.plotlyData], plotlyLayout, plotlyConfig ); - }, [plotlyConfig, plotlyData, plotlyLayout]); + }, [plotlyConfig, props.plotlyData, plotlyLayout]); return
; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx index d459d1354f1..47d0fa3f10c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx @@ -31,13 +31,13 @@ export const PlotlyRadarPlot: React.FC<{ }, [props.data]); const plotlyLayout: Partial = useMemo(() => { return { - height: props.height, + height: props.height - 40, showlegend: false, margin: { - l: 60, - r: 0, + l: 20, + r: 20, b: 30, - t: 30, + t: 20, pad: 0, }, polar: { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx index 5bfaa8fcb04..c23ffcce04d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx @@ -1,15 +1,6 @@ import {Box} from '@material-ui/core'; -import {Popover} from '@mui/material'; -import {Switch} from '@wandb/weave/components'; import {Button} from '@wandb/weave/components/Button'; -import { - DraggableGrow, - DraggableHandle, -} from '@wandb/weave/components/DraggablePopups'; -import {TextField} from '@wandb/weave/components/Form/TextField'; import {Tailwind} from '@wandb/weave/components/Tailwind'; -import {maybePluralize} from '@wandb/weave/core/util/string'; -import classNames from 'classnames'; import React, {useEffect, useMemo, useRef, useState} from 'react'; import {buildCompositeMetricsMap} from '../../compositeMetricsUtil'; @@ -27,6 +18,7 @@ import { resolveSummaryMetricValueForEvaluateCall, } from '../../ecpUtil'; import {HorizontalBox, VerticalBox} from '../../Layout'; +import {MetricsSelector} from './MetricsSelector'; import {PlotlyBarPlot} from './PlotlyBarPlot'; import {PlotlyRadarPlot, RadarPlotData} from './PlotlyRadarPlot'; @@ -36,15 +28,12 @@ import {PlotlyRadarPlot, RadarPlotData} from './PlotlyRadarPlot'; export const SummaryPlots: React.FC<{ state: EvaluationComparisonState; setSelectedMetrics: (newModel: Record) => void; -}> = props => { - const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics( - props.state - ); - const {selectedMetrics} = props.state; - const setSelectedMetrics = props.setSelectedMetrics; +}> = ({state, setSelectedMetrics}) => { + const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics(state); + const {selectedMetrics} = state; + // Initialize selectedMetrics if null useEffect(() => { - // If selectedMetrics is null, we should show all metrics if (selectedMetrics == null) { setSelectedMetrics( Object.fromEntries(Array.from(allMetricNames).map(m => [m, true])) @@ -52,10 +41,184 @@ export const SummaryPlots: React.FC<{ } }, [selectedMetrics, setSelectedMetrics, allMetricNames]); - // filter down the plotlyRadarData to only include the selected metrics, after - // computation, to allow quick addition/removal of metrics - const filteredPlotlyRadarData = useMemo(() => { - const filteredData: RadarPlotData = {}; + const filteredData = useFilteredData(radarData, selectedMetrics); + const normalizedRadarData = normalizeDataForRadarPlot(filteredData); + const barPlotData = useBarPlotData(filteredData); + + const { + containerRef, + isInitialRender, + plotsPerPage, + currentPage, + setCurrentPage, + } = useContainerDimensions(); + + const {plotsToShow, totalPlots, startIndex, endIndex, totalPages} = + usePaginatedPlots( + normalizedRadarData, + barPlotData, + plotsPerPage, + currentPage + ); + + // Render placeholder during initial render + if (isInitialRender) { + return
; + } + + return ( + + +
+ {plotsToShow} +
+ setCurrentPage(prev => Math.max(prev - 1, 0))} + onNextPage={() => + setCurrentPage(prev => Math.min(prev + 1, totalPages - 1)) + } + /> +
+ ); +}; + +const SectionHeader: React.FC<{ + selectedMetrics: Record | undefined; + setSelectedMetrics: (newModel: Record) => void; + allMetrics: string[]; +}> = ({selectedMetrics, setSelectedMetrics, allMetrics}) => ( + + + Summary Metrics + + +
+
Configure displayed metrics
+ +
+
+
+); + +const RadarPlotBox: React.FC<{data: RadarPlotData}> = ({data}) => ( + + + +); + +const BarPlotBox: React.FC<{ + plot: {plotlyData: Plotly.Data; yRange: [number, number]}; +}> = ({plot}) => ( + + + +); + +const PaginationControls: React.FC<{ + currentPage: number; + totalPages: number; + startIndex: number; + endIndex: number; + totalPlots: number; + onPrevPage: () => void; + onNextPage: () => void; +}> = ({ + currentPage, + totalPages, + startIndex, + endIndex, + totalPlots, + onPrevPage, + onNextPage, +}) => ( + + + +
+
+
+
+
+); + +const useFilteredData = ( + radarData: RadarPlotData, + selectedMetrics: Record | undefined +) => + useMemo(() => { + const data: RadarPlotData = {}; for (const [callId, metricBin] of Object.entries(radarData)) { const metrics: {[metric: string]: number} = {}; for (const [metric, value] of Object.entries(metricBin.metrics)) { @@ -64,253 +227,195 @@ export const SummaryPlots: React.FC<{ } } if (Object.keys(metrics).length > 0) { - filteredData[callId] = { + data[callId] = { metrics, name: metricBin.name, color: metricBin.color, }; } } - return filteredData; + return data; }, [radarData, selectedMetrics]); - return ( - - - - Summary Metrics - - -
-
Configure displayed metrics
- -
-
-
- - - - - - - - -
- ); -}; +function getMetricValuesFromRadarData(radarData: RadarPlotData): { + [metric: string]: number[]; +} { + const metricValues: {[metric: string]: number[]} = {}; + // Gather all values for each metric + Object.values(radarData).forEach(callData => { + Object.entries(callData.metrics).forEach(([metric, value]) => { + if (!metricValues[metric]) { + metricValues[metric] = []; + } + metricValues[metric].push(value); + }); + }); + return metricValues; +} -const MetricsSelector: React.FC<{ - setSelectedMetrics: (newModel: Record) => void; - selectedMetrics: Record | undefined; - allMetrics: string[]; -}> = ({setSelectedMetrics, selectedMetrics, allMetrics}) => { - const [search, setSearch] = useState(''); - - const ref = useRef(null); - const [anchorEl, setAnchorEl] = useState(null); - const onClick = (event: React.MouseEvent) => { - setAnchorEl(anchorEl ? null : ref.current); - setSearch(''); +function getMetricMinsFromRadarData(radarData: RadarPlotData): { + [metric: string]: number; +} { + const metricValues = getMetricValuesFromRadarData(radarData); + const metricMins: {[metric: string]: number} = {}; + Object.entries(metricValues).forEach(([metric, values]) => { + metricMins[metric] = Math.min(...values); + }); + return metricMins; +} + +function normalizeDataForRadarPlot(radarData: RadarPlotData): RadarPlotData { + const metricMins = getMetricMinsFromRadarData(radarData); + + const normalizedData: RadarPlotData = {}; + Object.entries(radarData).forEach(([callId, callData]) => { + normalizedData[callId] = { + name: callData.name, + color: callData.color, + metrics: {}, + }; + + Object.entries(callData.metrics).forEach(([metric, value]) => { + const min = metricMins[metric]; + // Only shift values if there are negative values + const normalizedValue = min < 0 ? value - min : value; + normalizedData[callId].metrics[metric] = normalizedValue; + }); + }); + + return normalizedData; +} + +const useBarPlotData = (filteredData: RadarPlotData) => + useMemo(() => { + const metrics: { + [metric: string]: { + callIds: string[]; + values: number[]; + name: string; + colors: string[]; + }; + } = {}; + + // Reorganize data by metric instead of by call + for (const [callId, metricBin] of Object.entries(filteredData)) { + for (const [metric, value] of Object.entries(metricBin.metrics)) { + if (!metrics[metric]) { + metrics[metric] = {callIds: [], values: [], name: metric, colors: []}; + } + metrics[metric].callIds.push(callId); + metrics[metric].values.push(value); + metrics[metric].colors.push(metricBin.color); + } + } + + // Convert metrics object to Plotly data format + return Object.entries(metrics).map(([metric, metricBin]) => { + const maxY = Math.max(...metricBin.values) * 1.1; + const minY = Math.min(...metricBin.values, 0); + const plotlyData: Plotly.Data = { + type: 'bar', + y: metricBin.values, + x: metricBin.callIds, + text: metricBin.values.map(value => value.toFixed(3)), + textposition: 'outside', + textfont: {size: 14, color: 'black'}, + name: metric, + marker: {color: metricBin.colors}, + }; + return {plotlyData, yRange: [minY, maxY] as [number, number]}; + }); + }, [filteredData]); + +const useContainerDimensions = () => { + const containerRef = useRef(null); + const [containerWidth, setContainerWidth] = useState(0); + const [isInitialRender, setIsInitialRender] = useState(true); + const [currentPage, setCurrentPage] = useState(0); + + useEffect(() => { + const updateWidth = () => { + if (containerRef.current) { + setContainerWidth(containerRef.current.offsetWidth); + } + }; + + updateWidth(); + setIsInitialRender(false); + + window.addEventListener('resize', updateWidth); + return () => window.removeEventListener('resize', updateWidth); + }, []); + + const plotsPerPage = useMemo(() => { + return Math.max(1, Math.floor(containerWidth / PLOT_HEIGHT)); + }, [containerWidth]); + + return { + containerRef, + isInitialRender, + plotsPerPage, + currentPage, + setCurrentPage, }; - const open = Boolean(anchorEl); - const id = open ? 'simple-popper' : undefined; +}; - const filteredCols = search - ? allMetrics.filter(col => col.toLowerCase().includes(search.toLowerCase())) - : allMetrics; +const usePaginatedPlots = ( + filteredData: RadarPlotData, + barPlotData: Array<{plotlyData: Plotly.Data; yRange: [number, number]}>, + plotsPerPage: number, + currentPage: number +) => { + const radarPlotWidth = 2; + const totalBarPlots = barPlotData.length; + const totalPlotWidth = radarPlotWidth + totalBarPlots; + const totalPages = Math.ceil(totalPlotWidth / plotsPerPage); - const shownMetrics = Object.values(selectedMetrics ?? {}).filter(Boolean); + const plotsToShow = useMemo(() => { + // First page always shows radar plot + if (currentPage === 0) { + const availableSpace = plotsPerPage - radarPlotWidth; + return [ + , + ...barPlotData + .slice(0, availableSpace) + .map((plot, index) => ( + + )), + ]; + } else { + // Subsequent pages show only bar plots + const startIdx = + (currentPage - 1) * plotsPerPage + (plotsPerPage - radarPlotWidth); + const endIdx = startIdx + plotsPerPage; + return barPlotData + .slice(startIdx, endIdx) + .map((plot, index) => ( + + )); + } + }, [currentPage, plotsPerPage, filteredData, barPlotData]); - const numHidden = allMetrics.length - shownMetrics.length; - const buttonSuffix = search ? `(${filteredCols.length})` : 'all'; + // Calculate pagination details + const totalPlots = barPlotData.length + 1; // +1 for the radar plot + const startIndex = + currentPage === 0 ? 1 : Math.min(plotsPerPage + 1, totalPlots); + const endIndex = + currentPage === 0 + ? Math.min(plotsToShow.length, totalPlots) + : Math.min(startIndex + plotsToShow.length - 1, totalPlots); - return ( - <> - - -
- -
-
- - - - ); + return {plotsToShow, totalPlots, startIndex, endIndex, totalPages}; }; -const normalizeValues = (values: Array): number[] => { +function normalizeValues(values: Array): number[] { // find the max value // find the power of 2 that is greater than the max value // divide all values by that power of 2 const maxVal = Math.max(...(values.filter(v => v !== undefined) as number[])); const maxPower = Math.ceil(Math.log2(maxVal)); return values.map(val => (val ? val / 2 ** maxPower : 0)); -}; +} const useNormalizedPlotDataFromMetrics = ( state: EvaluationComparisonState