From 6f5796da7807c530a56d68bd752b68e41a6ac0ea Mon Sep 17 00:00:00 2001 From: KyleGoyette Date: Wed, 4 Dec 2024 14:56:56 -0800 Subject: [PATCH 01/60] chore(app): Add launch is Active to `useProjectSidebar` (#3005) * hide launch by default * add to usememo dep * prettier --- weave-js/src/components/FancyPage/useProjectSidebar.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/weave-js/src/components/FancyPage/useProjectSidebar.ts b/weave-js/src/components/FancyPage/useProjectSidebar.ts index 34d9bb890b6..19dce4215df 100644 --- a/weave-js/src/components/FancyPage/useProjectSidebar.ts +++ b/weave-js/src/components/FancyPage/useProjectSidebar.ts @@ -10,7 +10,8 @@ export const useProjectSidebar = ( hasModelsData: boolean, hasWeaveData: boolean, hasTraceBackend: boolean = true, - hasModelsAccess: boolean = true + hasModelsAccess: boolean = true, + isLaunchActive: boolean = false ): FancyPageSidebarItem[] => { // Should show models sidebar items if we have models data or if we don't have a trace backend let showModelsSidebarItems = hasModelsData || !hasTraceBackend; @@ -68,7 +69,7 @@ export const useProjectSidebar = ( type: 'button' as const, name: 'Jobs', slug: 'jobs', - isShown: isModelsOnly, + isShown: isModelsOnly && isLaunchActive, isDisabled: viewingRestricted, iconName: IconNames.FlashBolt, }, @@ -250,5 +251,6 @@ export const useProjectSidebar = ( viewingRestricted, isModelsOnly, showWeaveSidebarItems, + isLaunchActive, ]); }; From f716e17908b015d4a481bb2eaaf6a56246f170f5 Mon Sep 17 00:00:00 2001 From: Ben Sherman Date: Wed, 4 Dec 2024 15:46:31 -0800 Subject: [PATCH 02/60] chore(ui): right drawer doesn't occlude call metrics charts (#3149) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Change size of metric charts section to be full container width instead of fixed based on window size. Create plotly chart directly in the component render for metric charts, instead of in an effect that depends on only a few of the props. Adds a 200ms debounce to all metric chart updates so resizing the drawer doesn't cause excessive re-rendering of the plots. before: ![Screenshot 2024-12-04 at 2 15 04 PM](https://github.com/user-attachments/assets/49a77718-a4e2-486f-af37-8dfae977d63d) after: ![Screenshot 2024-12-04 at 2 15 26 PM](https://github.com/user-attachments/assets/fb2288b4-3e44-475f-b40f-b310ccdcb129) --- .../Browse3/pages/CallsPage/CallsCharts.tsx | 2 +- .../Home/Browse3/pages/CallsPage/Charts.tsx | 138 +++++++++--------- 2 files changed, 71 insertions(+), 69 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCharts.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCharts.tsx index bdebf29f697..0519dd2161d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCharts.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCharts.tsx @@ -182,7 +182,7 @@ export const CallsCharts = ({ return ( {/* setting the width to the width of the screen minus the sidebar width because of overflow: 'hidden' properties in SimplePageLayout causing issues */} -
+
{charts}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/Charts.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/Charts.tsx index 7d1e9313069..c04ce275483 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/Charts.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/Charts.tsx @@ -2,7 +2,7 @@ import {quantile} from 'd3-array'; import _ from 'lodash'; import moment from 'moment'; import * as Plotly from 'plotly.js'; -import React, {useEffect, useMemo, useRef} from 'react'; +import React, {useMemo, useRef} from 'react'; import { BLUE_500, @@ -64,6 +64,8 @@ const Y_AXIS_STYLE: Partial = { zeroline: false, }; +const PLOT_DEBOUNCE_MS = 200; + export const calculateBinSize = ( data: ChartDataLatency[] | ChartDataErrors[] | ChartDataRequests[], targetBinCount = 15 @@ -147,28 +149,28 @@ export const LatencyPlotlyChart: React.FC<{ ]; }, [chartData, binSize]); - useEffect(() => { - const plotlyLayout: Partial = { - height, - margin: CHART_MARGIN_STYLE, - xaxis: X_AXIS_STYLE_WITH_SPIKES, - yaxis: Y_AXIS_STYLE, - hovermode: 'x unified', - showlegend: false, - hoverlabel: { - bordercolor: MOON_200, - }, - }; - - const plotlyConfig: Partial = { - displayModeBar: false, - responsive: true, - }; - - if (divRef.current) { - Plotly.newPlot(divRef.current, plotlyData, plotlyLayout, plotlyConfig); - } - }, [plotlyData, height]); + const plotlyLayout: Partial = { + height, + margin: CHART_MARGIN_STYLE, + xaxis: X_AXIS_STYLE_WITH_SPIKES, + yaxis: Y_AXIS_STYLE, + hovermode: 'x unified', + showlegend: false, + hoverlabel: { + bordercolor: MOON_200, + }, + }; + + const plotlyConfig: Partial = { + displayModeBar: false, + responsive: true, + }; + + if (divRef.current) { + _.debounce(() => { + Plotly.newPlot(divRef.current!, plotlyData, plotlyLayout, plotlyConfig); + }, PLOT_DEBOUNCE_MS)(); + } return
; }; @@ -206,29 +208,29 @@ export const ErrorPlotlyChart: React.FC<{ ]; }, [chartData, binSize]); - useEffect(() => { - const plotlyLayout: Partial = { - height, - margin: CHART_MARGIN_STYLE, - bargap: 0.2, - xaxis: X_AXIS_STYLE, - yaxis: Y_AXIS_STYLE, - hovermode: 'x unified', - hoverlabel: { - bordercolor: MOON_200, - }, - dragmode: 'zoom', - }; - - const plotlyConfig: Partial = { - displayModeBar: false, - responsive: true, - }; - - if (divRef.current) { - Plotly.newPlot(divRef.current, plotlyData, plotlyLayout, plotlyConfig); - } - }, [plotlyData, height]); + const plotlyLayout: Partial = { + height, + margin: CHART_MARGIN_STYLE, + bargap: 0.2, + xaxis: X_AXIS_STYLE, + yaxis: Y_AXIS_STYLE, + hovermode: 'x unified', + hoverlabel: { + bordercolor: MOON_200, + }, + dragmode: 'zoom', + }; + + const plotlyConfig: Partial = { + displayModeBar: false, + responsive: true, + }; + + if (divRef.current) { + _.debounce(() => { + Plotly.newPlot(divRef.current!, plotlyData, plotlyLayout, plotlyConfig); + }, PLOT_DEBOUNCE_MS)(); + } return
; }; @@ -266,28 +268,28 @@ export const RequestsPlotlyChart: React.FC<{ ]; }, [chartData, binSize]); - useEffect(() => { - const plotlyLayout: Partial = { - height, - margin: CHART_MARGIN_STYLE, - xaxis: X_AXIS_STYLE, - yaxis: Y_AXIS_STYLE, - bargap: 0.2, - hovermode: 'x unified', - hoverlabel: { - bordercolor: MOON_200, - }, - }; - - const plotlyConfig: Partial = { - displayModeBar: false, - responsive: true, - }; - - if (divRef.current) { - Plotly.newPlot(divRef.current, plotlyData, plotlyLayout, plotlyConfig); - } - }, [plotlyData, height]); + const plotlyLayout: Partial = { + height, + margin: CHART_MARGIN_STYLE, + xaxis: X_AXIS_STYLE, + yaxis: Y_AXIS_STYLE, + bargap: 0.2, + hovermode: 'x unified', + hoverlabel: { + bordercolor: MOON_200, + }, + }; + + const plotlyConfig: Partial = { + displayModeBar: false, + responsive: true, + }; + + if (divRef.current) { + _.debounce(() => { + Plotly.newPlot(divRef.current!, plotlyData, plotlyLayout, plotlyConfig); + }, PLOT_DEBOUNCE_MS)(); + } return
; }; From 7c10a7ef2c2fa6677bc13c8c8be61fc45ffb33cf Mon Sep 17 00:00:00 2001 From: Josiah Lee Date: Wed, 4 Dec 2024 17:00:54 -0800 Subject: [PATCH 03/60] add o1-preview back and latest 4o (#3151) --- .../Browse3/pages/PlaygroundPage/llmMaxTokens.ts | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/llmMaxTokens.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/llmMaxTokens.ts index 4b12110cc38..1fbb329b3f0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/llmMaxTokens.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/llmMaxTokens.ts @@ -162,6 +162,11 @@ export const LLM_MAX_TOKENS = { max_tokens: 4096, supports_function_calling: true, }, + 'gpt-4o-2024-11-20': { + provider: 'openai', + max_tokens: 4096, + supports_function_calling: true, + }, 'groq/gemma-7b-it': { provider: 'groq', max_tokens: 8192, @@ -207,6 +212,16 @@ export const LLM_MAX_TOKENS = { max_tokens: 65536, supports_function_calling: true, }, + 'o1-preview-2024-09-12': { + provider: 'openai', + max_tokens: 32768, + supports_function_calling: true, + }, + 'o1-preview': { + provider: 'openai', + max_tokens: 32768, + supports_function_calling: true, + }, 'ai21.j2-mid-v1': { provider: 'bedrock', From a259e4a1cc3be26abe08b76b8c6e81552a477206 Mon Sep 17 00:00:00 2001 From: Ben Sherman Date: Thu, 5 Dec 2024 09:11:43 -0800 Subject: [PATCH 04/60] chore(ui): add custom styles to column menu items (#3152) --- .../pages/CallsPage/CallsCustomColumnMenu.tsx | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCustomColumnMenu.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCustomColumnMenu.tsx index 30420e3cbcd..79cdbba5249 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCustomColumnMenu.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCustomColumnMenu.tsx @@ -3,6 +3,7 @@ * as we implement our own UI outside the grid for this. We still want the "Hide column" item, * which is inconveniently tied to the "Manage columns" item in `columnMenuColumnsItem`. */ +import {createTheme, ThemeProvider} from '@mui/material/styles'; import { GridColumnMenu, GridColumnMenuHideItem, @@ -12,11 +13,27 @@ import React from 'react'; type Slots = Record | null>; -// See: https://mui.com/x/react-data-grid/column-menu/#customize-column-menu-items +const columnMenuTheme = createTheme({ + components: { + MuiTypography: { + styleOverrides: { + root: { + fontSize: '14px', + fontFamily: 'Source Sans Pro', + fontWeight: 400, + }, + }, + }, + }, +}); export const CallsCustomColumnMenu = (props: GridColumnMenuProps) => { const slots: Slots = {columnMenuColumnsItem: null}; if (props.colDef.hideable ?? true) { slots.columnMenuUserItem = GridColumnMenuHideItem; } - return ; + return ( + + + + ); }; From 535f81186f8de3d214c5d0eb5c7182a699755314 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Thu, 5 Dec 2024 22:04:40 -0500 Subject: [PATCH 05/60] chore(weave): Temporarily pin wandb version in mypy to unblock CI (#3160) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c45dda5fa0..a5649a71720 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: hooks: - id: mypy additional_dependencies: - [types-pkg-resources==0.1.3, types-all, wandb>=0.15.5] + [types-pkg-resources==0.1.3, types-all, wandb>=0.15.5, wandb<0.19.0] # Note: You have to update pyproject.toml[tool.mypy] too! args: ["--config-file=pyproject.toml"] exclude: (.*pyi$)|(weave_query)|(tests)|(examples) From 2b1fa5a9e5eb2126472dacf8c9395d4728ede197 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Thu, 5 Dec 2024 22:13:41 -0500 Subject: [PATCH 06/60] chore(weave): Tidy op.py (#3143) --- weave/trace/op.py | 28 +++++++--------------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index 45147789dee..5e33e8bdbf8 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -34,8 +34,6 @@ logger = logging.getLogger(__name__) -WEAVE_KWARGS_KEY = "__weave" - if TYPE_CHECKING: from weave.trace.weave_client import Call, CallsIter @@ -54,17 +52,6 @@ except ImportError: ANTHROPIC_NOT_GIVEN = None -try: - # https://github.com/search?q=repo:mistralai/client-python%20Final&type=code - from mistralai.types.basemodel import UNSET # type: ignore - - MISTRAL_NOT_GIVEN = UNSET # type: ignore -except ImportError: - MISTRAL_NOT_GIVEN = None - -MISTRAL_NOT_GIVEN = None - - try: from cerebras.cloud.sdk._types import NOT_GIVEN as CEREBRAS_NOT_GIVEN except ImportError: @@ -105,14 +92,13 @@ class ProcessedInputs: def _value_is_sentinel(param: Any) -> bool: - return ( - param.default is None - or param.default is OPENAI_NOT_GIVEN - or param.default is COHERE_NOT_GIVEN - or param.default is ANTHROPIC_NOT_GIVEN - or param.default is MISTRAL_NOT_GIVEN - or param.default is CEREBRAS_NOT_GIVEN - or param.default is Ellipsis + return param.default in ( + None, + Ellipsis, + OPENAI_NOT_GIVEN, + COHERE_NOT_GIVEN, + ANTHROPIC_NOT_GIVEN, + CEREBRAS_NOT_GIVEN, ) From 43ab785e1c046e0bdb06da4100f10e27707f75ae Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Fri, 6 Dec 2024 10:39:35 -0600 Subject: [PATCH 07/60] refactor(ui): break an import cycle (#3155) --- .../PagePanelComponents/Home/Browse3.tsx | 61 +------------------ .../Home/Browse3/pages/CallPage/CallPage.tsx | 2 +- .../pages/CallPage/PaginationControls.tsx | 2 +- .../Browse3/pages/CallsPage/CallsTable.tsx | 2 +- .../Home/TableRowSelectionContext.tsx | 61 +++++++++++++++++++ 5 files changed, 65 insertions(+), 63 deletions(-) create mode 100644 weave-js/src/components/PagePanelComponents/Home/TableRowSelectionContext.tsx diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index 976e232716a..3517f4d3b9c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -104,6 +104,7 @@ import { WFDataModelAutoProvider, } from './Browse3/pages/wfReactInterface/context'; import {useHasTraceServerClientContext} from './Browse3/pages/wfReactInterface/traceServerClientContext'; +import {TableRowSelectionProvider} from './TableRowSelectionContext'; import {useDrawerResize} from './useDrawerResize'; LicenseInfo.setLicenseKey( @@ -1149,63 +1150,3 @@ const Browse3Breadcrumbs: FC = props => { ); }; - -export const TableRowSelectionContext = React.createContext<{ - rowIdsConfigured: boolean; - rowIdInTable: (id: string) => boolean; - setRowIds?: (rowIds: string[]) => void; - getNextRowId?: (currentId: string) => string | null; - getPreviousRowId?: (currentId: string) => string | null; -}>({ - rowIdsConfigured: false, - rowIdInTable: (id: string) => false, - setRowIds: () => {}, - getNextRowId: () => null, - getPreviousRowId: () => null, -}); - -const TableRowSelectionProvider: FC<{children: React.ReactNode}> = ({ - children, -}) => { - const [rowIds, setRowIds] = useState([]); - const rowIdsConfigured = useMemo(() => rowIds.length > 0, [rowIds]); - const rowIdInTable = useCallback( - (currentId: string) => rowIds.includes(currentId), - [rowIds] - ); - - const getNextRowId = useCallback( - (currentId: string) => { - const currentIndex = rowIds.indexOf(currentId); - if (currentIndex !== -1) { - return rowIds[currentIndex + 1]; - } - return null; - }, - [rowIds] - ); - - const getPreviousRowId = useCallback( - (currentId: string) => { - const currentIndex = rowIds.indexOf(currentId); - if (currentIndex !== -1) { - return rowIds[currentIndex - 1]; - } - return null; - }, - [rowIds] - ); - - return ( - - {children} - - ); -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index 545c64eb4b9..3e4a74a4885 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -17,7 +17,7 @@ import {makeRefCall} from '../../../../../../util/refs'; import {Button} from '../../../../../Button'; import {Tailwind} from '../../../../../Tailwind'; import {Browse2OpDefCode} from '../../../Browse2/Browse2OpDefCode'; -import {TableRowSelectionContext} from '../../../Browse3'; +import {TableRowSelectionContext} from '../../../TableRowSelectionContext'; import { FEEDBACK_EXPAND_PARAM, TRACETREE_PARAM, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/PaginationControls.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/PaginationControls.tsx index 9a2222e7e0a..7254851929a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/PaginationControls.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/PaginationControls.tsx @@ -3,7 +3,7 @@ import {Button} from '@wandb/weave/components/Button'; import React, {FC, useCallback, useContext, useEffect} from 'react'; import {useHistory} from 'react-router-dom'; -import {TableRowSelectionContext} from '../../../Browse3'; +import {TableRowSelectionContext} from '../../../TableRowSelectionContext'; import { FEEDBACK_EXPAND_PARAM, TRACETREE_PARAM, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx index d7db3324ac9..e6e62180f64 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx @@ -45,7 +45,7 @@ import {useViewerInfo} from '../../../../../../common/hooks/useViewerInfo'; import {A, TargetBlank} from '../../../../../../common/util/links'; import {TailwindContents} from '../../../../../Tailwind'; import {flattenObjectPreservingWeaveTypes} from '../../../Browse2/browse2Util'; -import {TableRowSelectionContext} from '../../../Browse3'; +import {TableRowSelectionContext} from '../../../TableRowSelectionContext'; import { useWeaveflowCurrentRouteContext, WeaveflowPeekContext, diff --git a/weave-js/src/components/PagePanelComponents/Home/TableRowSelectionContext.tsx b/weave-js/src/components/PagePanelComponents/Home/TableRowSelectionContext.tsx new file mode 100644 index 00000000000..a0e14a3367a --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/TableRowSelectionContext.tsx @@ -0,0 +1,61 @@ +import React, {FC, useCallback, useMemo, useState} from 'react'; + +export const TableRowSelectionContext = React.createContext<{ + rowIdsConfigured: boolean; + rowIdInTable: (id: string) => boolean; + setRowIds?: (rowIds: string[]) => void; + getNextRowId?: (currentId: string) => string | null; + getPreviousRowId?: (currentId: string) => string | null; +}>({ + rowIdsConfigured: false, + rowIdInTable: (id: string) => false, + setRowIds: () => {}, + getNextRowId: () => null, + getPreviousRowId: () => null, +}); + +export const TableRowSelectionProvider: FC<{children: React.ReactNode}> = ({ + children, +}) => { + const [rowIds, setRowIds] = useState([]); + const rowIdsConfigured = useMemo(() => rowIds.length > 0, [rowIds]); + const rowIdInTable = useCallback( + (currentId: string) => rowIds.includes(currentId), + [rowIds] + ); + + const getNextRowId = useCallback( + (currentId: string) => { + const currentIndex = rowIds.indexOf(currentId); + if (currentIndex !== -1) { + return rowIds[currentIndex + 1]; + } + return null; + }, + [rowIds] + ); + + const getPreviousRowId = useCallback( + (currentId: string) => { + const currentIndex = rowIds.indexOf(currentId); + if (currentIndex !== -1) { + return rowIds[currentIndex - 1]; + } + return null; + }, + [rowIds] + ); + + return ( + + {children} + + ); +}; From a06b6ace52988bff45012d78a29c9adbbeb32249 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 6 Dec 2024 08:57:21 -0800 Subject: [PATCH 08/60] chore(weave): Improve large object viewer renders (#3102) * init * init * Larger refactor * Larger refactor * Larger refactor * Larger refactor * Larger refactor --- .../Browse3/pages/CallPage/CallDetails.tsx | 41 ++----- .../Browse3/pages/CallPage/ObjectViewer.tsx | 81 ++++++++----- .../pages/CallPage/ObjectViewerSection.tsx | 4 +- .../traceServerCachingClient.ts | 106 ++++++++++++++++++ .../wfReactInterface/traceServerClient.ts | 4 +- .../typeViews/CustomWeaveTypeDispatcher.tsx | 8 ++ .../PIL.Image.Image/PILImageImage.tsx | 5 +- 7 files changed, 183 insertions(+), 66 deletions(-) create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerCachingClient.ts diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx index 8a9c011b5de..26be9f0c5d0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx @@ -1,4 +1,3 @@ -import {Typography} from '@mui/material'; import Box from '@mui/material/Box'; import _ from 'lodash'; import React, {FC, useContext, useMemo} from 'react'; @@ -10,9 +9,7 @@ import {Button} from '../../../../../Button'; import {useWeaveflowRouteContext, WeaveflowPeekContext} from '../../context'; import {CustomWeaveTypeProjectContext} from '../../typeViews/CustomWeaveTypeDispatcher'; import {CallsTable} from '../CallsPage/CallsTable'; -import {KeyValueTable} from '../common/KeyValueTable'; -import {CallLink, opNiceName} from '../common/Links'; -import {CenteredAnimatedLoader} from '../common/Loader'; +import {CallLink} from '../common/Links'; import {useWFHooks} from '../wfReactInterface/context'; import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; import {ButtonOverlay} from './ButtonOverlay'; @@ -20,6 +17,8 @@ import {ExceptionDetails, getExceptionInfo} from './Exceptions'; import {ObjectViewerSection} from './ObjectViewerSection'; import {OpVersionText} from './OpVersionText'; +const HEADER_HEIGHT_BUFFER = 60; + const Heading = styled.div` color: ${MOON_800}; font-weight: 600; @@ -104,7 +103,7 @@ export const CallDetails: FC<{ columns ); - const {singularChildCalls, multipleChildCallOpRefs} = useMemo( + const {multipleChildCallOpRefs} = useMemo( () => callGrouping(!childCalls.loading ? childCalls.result ?? [] : []), [childCalls.loading, childCalls.result] ); @@ -133,6 +132,7 @@ export const CallDetails: FC<{ 0 ? HEADER_HEIGHT_BUFFER : 0 + }px)`, p: 2, }}> {'traceback' in excInfo ? ( @@ -201,7 +204,6 @@ export const CallDetails: FC<{ sx={{ flex: '0 0 auto', height: '500px', - maxHeight: '95%', p: 2, display: 'flex', flexDirection: 'column', @@ -240,33 +242,6 @@ export const CallDetails: FC<{ ); })} - {childCalls.loading && } - {/* Disabling display of singular children while we decide if we want them here. */} - {false && singularChildCalls.length > 0 && ( - - {multipleChildCallOpRefs.length === 0 ? ( - Child calls - ) : ( - Singular child calls - )} - {singularChildCalls.map(c => ( - - - - ))} - - )} ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx index d741db64771..e8f39ad2880 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx @@ -23,7 +23,11 @@ import {LoadingDots} from '../../../../../LoadingDots'; import {Browse2OpDefCode} from '../../../Browse2/Browse2OpDefCode'; import {isWeaveRef} from '../../filters/common'; import {StyledDataGrid} from '../../StyledDataGrid'; -import {isCustomWeaveTypePayload} from '../../typeViews/customWeaveType.types'; +import { + CustomWeaveTypePayload, + isCustomWeaveTypePayload, +} from '../../typeViews/customWeaveType.types'; +import {getCustomWeaveTypePreferredRowHeight} from '../../typeViews/CustomWeaveTypeDispatcher'; import { LIST_INDEX_EDGE_NAME, OBJECT_ATTR_EDGE_NAME, @@ -48,6 +52,10 @@ import { } from './traverse'; import {ValueView} from './ValueView'; +const DEFAULT_ROW_HEIGHT = 38; +const CODE_ROW_HEIGHT = 350; +const TABLE_ROW_HEIGHT = 450; + type Data = Record | any[]; type ObjectViewerProps = { @@ -440,6 +448,47 @@ export const ObjectViewer = ({ }); }, [apiRef, expandedIds, updateRowExpand]); + // Per https://mui.com/x/react-data-grid/row-height/#dynamic-row-height, always + // memoize the getRowHeight function. + const getRowHeight = useCallback((params: GridRowHeightParams) => { + const isNonRefString = + params.model.valueType === 'string' && !isWeaveRef(params.model.value); + const isArray = params.model.valueType === 'array'; + const isTableRef = + isWeaveRef(params.model.value) && + (parseRefMaybe(params.model.value) as any).weaveKind === 'table'; + const {isCode} = params.model; + const isCustomWeaveType = isCustomWeaveTypePayload(params.model.value); + if (!params.model.isLeaf) { + // This is a group header, so we want to use the default height + return DEFAULT_ROW_HEIGHT; + } else if (isNonRefString) { + // This is the only special case where we will allow for dynamic height. + // Since strings have special renders that take up different amounts of + // space, we will allow for dynamic height. + return 'auto'; + } else if (isCustomWeaveType) { + const type = (params.model.value as CustomWeaveTypePayload).weave_type + .type; + const preferredRowHeight = getCustomWeaveTypePreferredRowHeight(type); + if (preferredRowHeight) { + return preferredRowHeight; + } + return DEFAULT_ROW_HEIGHT; + } else if ((isArray && USE_TABLE_FOR_ARRAYS) || isTableRef) { + // Perfectly enough space for 1 page of data rows + return TABLE_ROW_HEIGHT; + } else if (isCode) { + // Probably will get negative feedback here since code that is < 20 lines + // will have some whitespace below the code. However, we absolutely need + // to have static height for all cells else the MUI data grid will jump around + // when cleaning up virtual rows. + return CODE_ROW_HEIGHT; + } else { + return DEFAULT_ROW_HEIGHT; + } + }, []); + // Finally, we memoize the inner data grid component. This is important to // reduce the number of re-renders when the data changes. const inner = useMemo(() => { @@ -473,31 +522,9 @@ export const ObjectViewer = ({ isGroupExpandedByDefault={node => { return expandedIds.includes(node.id); }} - autoHeight columnHeaderHeight={38} - getRowHeight={(params: GridRowHeightParams) => { - const isNonRefString = - params.model.valueType === 'string' && - !isWeaveRef(params.model.value); - const isArray = params.model.valueType === 'array'; - const isTableRef = - isWeaveRef(params.model.value) && - (parseRefMaybe(params.model.value) as any).weaveKind === 'table'; - const {isCode} = params.model; - const isCustomWeaveType = isCustomWeaveTypePayload( - params.model.value - ); - if ( - isNonRefString || - (isArray && USE_TABLE_FOR_ARRAYS) || - isTableRef || - isCode || - isCustomWeaveType - ) { - return 'auto'; - } - return 38; - }} + rowHeight={DEFAULT_ROW_HEIGHT} + getRowHeight={getRowHeight} hideFooter rowSelection={false} groupingColDef={groupingColDef} @@ -517,10 +544,10 @@ export const ObjectViewer = ({ }} /> ); - }, [apiRef, rows, columns, groupingColDef, expandedIds]); + }, [apiRef, rows, columns, getRowHeight, groupingColDef, expandedIds]); // Return the inner data grid wrapped in a div with overflow hidden. - return
{inner}
; + return
{inner}
; }; // Helper function to build the base ref for a given path. This function is used diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx index d95d9a841de..d02a2e881ca 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx @@ -185,7 +185,7 @@ const ObjectViewerSectionNonEmpty = ({ }, [data, isExpanded]); return ( - <> + {title} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx index b5f675633fe..94cd3c17644 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx @@ -115,11 +115,15 @@ export const PlaygroundChat = ({ }}> -
+
{state.traceCall && ( = ({ width: '100%', display: 'flex', justifyContent: 'space-between', - paddingBottom: '8px', }}> - - {props.headerContent} - + {props.headerContent && ( + + {props.headerContent} + + )} {(!props.hideTabsIfSingle || props.tabs.length > 1) && ( Date: Mon, 9 Dec 2024 12:44:29 -0500 Subject: [PATCH 21/60] chore(ui): Update metadata header styling (#3049) * Object metadata to Tailwind / refactor * Updated versions link and styling * Updated OpsVersionLink * Updated CallsLink to include next icon * Lint * Better flow for column sizes * Lint * No more colons --- .../Home/Browse3/pages/ObjectVersionPage.tsx | 83 ++++++++-------- .../Home/Browse3/pages/OpVersionPage.tsx | 98 ++++++++++++------- .../Home/Browse3/pages/common/Links.tsx | 22 +++-- 3 files changed, 119 insertions(+), 84 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx index 6108ed5407e..085587a64d8 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx @@ -211,47 +211,50 @@ const ObjectVersionPageInner: React.FC<{ } headerContent={ - - {objectName}{' '} - {objectVersions.loading ? ( - - ) : ( - <> - [ - +
+
+

Name

+
+ +
+ {objectName} + {objectVersions.loading ? ( + + ) : ( + + ({objectVersionCount} version + {objectVersionCount !== 1 ? 's' : ''}) + + )} + - ] - - )} - - ), - Version: <>{objectVersionIndex}, - ...(refExtra - ? { - Subpath: refExtra, - } - : {}), - // 'Type Version': ( - // - // ), - }} - /> +
+
+
+
+
+

Version

+

{objectVersionIndex}

+
+ {refExtra && ( +
+

Subpath

+

{refExtra}

+
+ )} +
+ } // menuItems={[ // { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx index 1a6e4afc577..36f4e44afc5 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx @@ -1,5 +1,6 @@ import React, {useMemo} from 'react'; +import {Icon} from '../../../../Icon'; import {LoadingDots} from '../../../../LoadingDots'; import {Tailwind} from '../../../../Tailwind'; import {NotFoundPanel} from '../NotFoundPanel'; @@ -13,7 +14,6 @@ import { import {CenteredAnimatedLoader} from './common/Loader'; import { ScrollableTabContent, - SimpleKeyValueTable, SimplePageLayoutWithHeader, } from './common/SimplePageLayout'; import {TabUseOp} from './TabUseOp'; @@ -75,49 +75,71 @@ const OpVersionPageInner: React.FC<{ - {opId}{' '} - {opVersions.loading ? ( - - ) : ( - <> - [ - - ] - - )} - - ), - Version: <>{versionIndex}, - Calls: - !callsStats.loading || opVersionCallCount > 0 ? ( - +
+
+

Name

+
+ + variant="secondary"> +
+ {opId} + {opVersions.loading ? ( + + ) : ( + + ({opVersionCount} version + {opVersionCount !== 1 ? 's' : ''}) + + )} + +
+
+
+
+
+

Version

+

{versionIndex}

+
+
+

Calls:

+ {!callsStats.loading || opVersionCallCount > 0 ? ( +
+ + +
) : ( - <> - ), - }} - /> +

-

+ )} +
+
+ } tabs={[ { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx index be78bb51367..4060735cc67 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx @@ -429,6 +429,7 @@ export const ObjectVersionsLink: React.FC<{ filter?: WFHighLevelObjectVersionFilter; neverPeek?: boolean; variant?: LinkVariant; + children?: React.ReactNode; }> = props => { const {peekingRouter, baseRouter} = useWeaveflowRouteContext(); const router = props.neverPeek ? baseRouter : peekingRouter; @@ -440,9 +441,13 @@ export const ObjectVersionsLink: React.FC<{ props.project, props.filter )}> - {props.versionCount} - {props.countIsLimited ? '+' : ''} version - {props.versionCount !== 1 ? 's' : ''} + {props.children ?? ( + <> + {props.versionCount} + {props.countIsLimited ? '+' : ''} version + {props.versionCount !== 1 ? 's' : ''} + + )} ); }; @@ -455,6 +460,7 @@ export const OpVersionsLink: React.FC<{ filter?: WFHighLevelOpVersionFilter; neverPeek?: boolean; variant?: LinkVariant; + children?: React.ReactNode; }> = props => { const {peekingRouter, baseRouter} = useWeaveflowRouteContext(); const router = props.neverPeek ? baseRouter : peekingRouter; @@ -462,9 +468,13 @@ export const OpVersionsLink: React.FC<{ - {props.versionCount} - {props.countIsLimited ? '+' : ''} version - {props.versionCount !== 1 ? 's' : ''} + {props.children ?? ( + <> + {props.versionCount} + {props.countIsLimited ? '+' : ''} version + {props.versionCount !== 1 ? 's' : ''} + + )} ); }; From 7320ca9a1fcc51808103403d61923e00361a02d9 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Mon, 9 Dec 2024 10:35:15 -0800 Subject: [PATCH 22/60] chore(ui): disable refresh while refreshing (#3179) --- .../Home/Browse3/pages/CallsPage/CallsTable.tsx | 5 ++++- .../Home/Browse3/pages/CallsPage/CallsTableButtons.tsx | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx index 18130d64341..3632664aaf0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx @@ -744,7 +744,10 @@ export const CallsTable: FC<{ }} filterListItems={ - calls.refetch()} /> + calls.refetch()} + disabled={callsLoading} + /> {!hideOpSelector && ( void; -}> = ({onClick}) => { + disabled?: boolean; +}> = ({onClick, disabled}) => { return ( From 0cfe058b709b77c96b901a0975f113e4d42d85a3 Mon Sep 17 00:00:00 2001 From: Connie Lee Date: Mon, 9 Dec 2024 11:39:23 -0800 Subject: [PATCH 23/60] style(app): Update tailwind shadow configs (#3180) --- weave-js/tailwind.config.cjs | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/weave-js/tailwind.config.cjs b/weave-js/tailwind.config.cjs index b234cc3cf1a..d8a580c928c 100644 --- a/weave-js/tailwind.config.cjs +++ b/weave-js/tailwind.config.cjs @@ -13,8 +13,13 @@ module.exports = { */ boxShadow: { none: 'none', - md: '0px 12px 24px 0px #15181F29', - lg: '0px 24px 48px 0px #15181F29', + flat: '0px 4px 8px 0px #0D0F120a', // oblivion 4% + medium: '0px 12px 24px 0px #0D0F1229', // oblivion 16% + deep: '0px 24px 48px 0px #0D0F123d', // oblivion 24% + + // deprecated shadow configs + md: '0px 12px 24px 0px #15181F29', // use shadow-medium instead + lg: '0px 24px 48px 0px #15181F29', // use shadow-deep instead }, spacing: { 0: '0rem', @@ -189,17 +194,17 @@ module.exports = { }, extend: { animation: { - 'wave': 'wave 3s linear infinite' + wave: 'wave 3s linear infinite', }, keyframes: { - "wave": { - "0%, 30%, 100%": { - transform: "initial" + wave: { + '0%, 30%, 100%': { + transform: 'initial', }, - "15%": { - transform: "translateY(-10px)" - } - } + '15%': { + transform: 'translateY(-10px)', + }, + }, }, opacity: { 35: '.35', @@ -221,6 +226,6 @@ module.exports = { in their parent hierarchy */ important: '.tw-style', experimental: { - optimizeUniversalDefaults: true + optimizeUniversalDefaults: true, }, }; From e9b73f7c44479cc783ddc12d59518b5e6e40f1b4 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 9 Dec 2024 12:13:06 -0800 Subject: [PATCH 24/60] feat(weave): Add AzureOpenAI support for Scorers (#3171) * init * init --- weave/scorers/llm_utils.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/weave/scorers/llm_utils.py b/weave/scorers/llm_utils.py index 8e2672b8ee7..68ae2ccb366 100644 --- a/weave/scorers/llm_utils.py +++ b/weave/scorers/llm_utils.py @@ -23,10 +23,17 @@ from google.generativeai import GenerativeModel from instructor.patch import InstructorChatCompletionCreate from mistralai import Mistral - from openai import AsyncOpenAI, OpenAI + from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI _LLM_CLIENTS = Union[ - OpenAI, AsyncOpenAI, Anthropic, AsyncAnthropic, Mistral, GenerativeModel + OpenAI, + AsyncOpenAI, + AzureOpenAI, + AsyncAzureOpenAI, + Anthropic, + AsyncAnthropic, + Mistral, + GenerativeModel, ] else: _LLM_CLIENTS = object @@ -34,6 +41,8 @@ _LLM_CLIENTS_NAMES = ( "OpenAI", "AsyncOpenAI", + "AzureOpenAI", + "AsyncAzureOpenAI", "Anthropic", "AsyncAnthropic", "Mistral", From 2351ecf4f867a7834262f9b68ed026ba77723b24 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Mon, 9 Dec 2024 15:15:14 -0600 Subject: [PATCH 25/60] chore(weave): add sun and moon icons (#3182) --- weave-js/src/assets/icons/icon-moon.svg | 3 +++ weave-js/src/assets/icons/icon-not-visible.svg | 2 +- weave-js/src/assets/icons/icon-pin-to-right.svg | 2 +- weave-js/src/assets/icons/icon-sun.svg | 11 +++++++++++ weave-js/src/components/Icon/Icon.tsx | 10 ++++++++++ weave-js/src/components/Icon/index.ts | 2 ++ weave-js/src/components/Icon/types.ts | 2 ++ 7 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 weave-js/src/assets/icons/icon-moon.svg create mode 100644 weave-js/src/assets/icons/icon-sun.svg diff --git a/weave-js/src/assets/icons/icon-moon.svg b/weave-js/src/assets/icons/icon-moon.svg new file mode 100644 index 00000000000..e448eab96c3 --- /dev/null +++ b/weave-js/src/assets/icons/icon-moon.svg @@ -0,0 +1,3 @@ + + + diff --git a/weave-js/src/assets/icons/icon-not-visible.svg b/weave-js/src/assets/icons/icon-not-visible.svg index 766810c7811..b2782d987b9 100644 --- a/weave-js/src/assets/icons/icon-not-visible.svg +++ b/weave-js/src/assets/icons/icon-not-visible.svg @@ -2,4 +2,4 @@ - \ No newline at end of file + diff --git a/weave-js/src/assets/icons/icon-pin-to-right.svg b/weave-js/src/assets/icons/icon-pin-to-right.svg index 1ae05ea52ae..46a9c0bf114 100644 --- a/weave-js/src/assets/icons/icon-pin-to-right.svg +++ b/weave-js/src/assets/icons/icon-pin-to-right.svg @@ -2,4 +2,4 @@ - \ No newline at end of file + diff --git a/weave-js/src/assets/icons/icon-sun.svg b/weave-js/src/assets/icons/icon-sun.svg new file mode 100644 index 00000000000..bb0c57891b0 --- /dev/null +++ b/weave-js/src/assets/icons/icon-sun.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/weave-js/src/components/Icon/Icon.tsx b/weave-js/src/components/Icon/Icon.tsx index 948c8a456b4..9c020b1bab9 100644 --- a/weave-js/src/components/Icon/Icon.tsx +++ b/weave-js/src/components/Icon/Icon.tsx @@ -139,6 +139,7 @@ import {ReactComponent as ImportMinimizeMode} from '../../assets/icons/icon-mini import {ReactComponent as ImportModel} from '../../assets/icons/icon-model.svg'; import {ReactComponent as ImportModelOnDark} from '../../assets/icons/icon-model-on-dark.svg'; import {ReactComponent as ImportMolecule} from '../../assets/icons/icon-molecule.svg'; +import {ReactComponent as ImportMoon} from '../../assets/icons/icon-moon.svg'; import {ReactComponent as ImportMusicAudio} from '../../assets/icons/icon-music-audio.svg'; import {ReactComponent as ImportNewSectionAbove} from '../../assets/icons/icon-new-section-above.svg'; import {ReactComponent as ImportNewSectionBelow} from '../../assets/icons/icon-new-section-below.svg'; @@ -216,6 +217,7 @@ import {ReactComponent as ImportStar} from '../../assets/icons/icon-star.svg'; import {ReactComponent as ImportStarFilled} from '../../assets/icons/icon-star-filled.svg'; import {ReactComponent as ImportStop} from '../../assets/icons/icon-stop.svg'; import {ReactComponent as ImportStopped} from '../../assets/icons/icon-stopped.svg'; +import {ReactComponent as ImportSun} from '../../assets/icons/icon-sun.svg'; import {ReactComponent as ImportSwap} from '../../assets/icons/icon-swap.svg'; import {ReactComponent as ImportSweepBayes} from '../../assets/icons/icon-sweep-bayes.svg'; import {ReactComponent as ImportSweepGrid} from '../../assets/icons/icon-sweep-grid.svg'; @@ -695,6 +697,9 @@ export const IconModelOnDark = (props: SVGIconProps) => ( export const IconMolecule = (props: SVGIconProps) => ( ); +export const IconMoon = (props: SVGIconProps) => ( + +); export const IconMusicAudio = (props: SVGIconProps) => ( ); @@ -926,6 +931,9 @@ export const IconStop = (props: SVGIconProps) => ( export const IconStopped = (props: SVGIconProps) => ( ); +export const IconSun = (props: SVGIconProps) => ( + +); export const IconSwap = (props: SVGIconProps) => ( ); @@ -1211,6 +1219,7 @@ const ICON_NAME_TO_ICON: Record = { model: IconModel, 'model-on-dark': IconModelOnDark, molecule: IconMolecule, + moon: IconMoon, 'music-audio': IconMusicAudio, 'new-section-above': IconNewSectionAbove, 'new-section-below': IconNewSectionBelow, @@ -1288,6 +1297,7 @@ const ICON_NAME_TO_ICON: Record = { 'star-filled': IconStarFilled, stop: IconStop, stopped: IconStopped, + sun: IconSun, swap: IconSwap, 'sweep-bayes': IconSweepBayes, 'sweep-grid': IconSweepGrid, diff --git a/weave-js/src/components/Icon/index.ts b/weave-js/src/components/Icon/index.ts index 81839a71019..85ea5332649 100644 --- a/weave-js/src/components/Icon/index.ts +++ b/weave-js/src/components/Icon/index.ts @@ -139,6 +139,7 @@ export { IconModel, IconModelOnDark, IconMolecule, + IconMoon, IconMusicAudio, IconNewSectionAbove, IconNewSectionBelow, @@ -216,6 +217,7 @@ export { IconStarFilled, IconStop, IconStopped, + IconSun, IconSwap, IconSweepBayes, IconSweepGrid, diff --git a/weave-js/src/components/Icon/types.ts b/weave-js/src/components/Icon/types.ts index 55f46c52833..87a1207bc85 100644 --- a/weave-js/src/components/Icon/types.ts +++ b/weave-js/src/components/Icon/types.ts @@ -138,6 +138,7 @@ export const IconNames = { Model: 'model', ModelOnDark: 'model-on-dark', Molecule: 'molecule', + Moon: 'moon', MusicAudio: 'music-audio', NewSectionAbove: 'new-section-above', NewSectionBelow: 'new-section-below', @@ -215,6 +216,7 @@ export const IconNames = { StarFilled: 'star-filled', Stop: 'stop', Stopped: 'stopped', + Sun: 'sun', Swap: 'swap', SweepBayes: 'sweep-bayes', SweepGrid: 'sweep-grid', From 8d0f3f7a6d84108fe52d335251d4080078dff0f0 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Mon, 9 Dec 2024 16:08:01 -0600 Subject: [PATCH 26/60] chore(ui): fix typo (#3183) --- .../Home/Browse3/pages/CallPage/CallPage.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index 3e4a74a4885..484b038c193 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -312,7 +312,7 @@ const CallPageInnerVertical: FC<{ const {rowIdsConfigured} = useContext(TableRowSelectionContext); const {isPeeking} = useContext(WeaveflowPeekContext); - const showPaginationContols = isPeeking && rowIdsConfigured; + const showPaginationControls = isPeeking && rowIdsConfigured; const callTabs = useCallTabs(currentCall); @@ -330,10 +330,10 @@ const CallPageInnerVertical: FC<{ justifyContent: 'space-between', alignItems: 'center', }}> - {showPaginationContols && ( + {showPaginationControls && ( )} - + +
+ +
+
+ + + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx index 9706ac09567..7942aea195e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx @@ -2,50 +2,54 @@ import * as Plotly from 'plotly.js'; import React, {useEffect, useMemo, useRef} from 'react'; import {PLOT_GRID_COLOR} from '../../ecpConstants'; -import {RadarPlotData} from './PlotlyRadarPlot'; export const PlotlyBarPlot: React.FC<{ height: number; - data: RadarPlotData; + yRange: [number, number]; + plotlyData: Plotly.Data; }> = props => { const divRef = useRef(null); - const plotlyData: Plotly.Data[] = useMemo(() => { - return Object.keys(props.data).map((key, i) => { - const {metrics, name, color} = props.data[key]; - return { - type: 'bar', - y: Object.values(metrics), - x: Object.keys(metrics), - name, - marker: {color}, - }; - }); - }, [props.data]); - const plotlyLayout: Partial = useMemo(() => { return { - height: props.height - 40, + height: props.height - 30, showlegend: false, margin: { - l: 0, + l: 20, r: 0, b: 20, - t: 0, - pad: 0, + t: 26, }, + bargap: 0.1, xaxis: { automargin: true, fixedrange: true, gridcolor: PLOT_GRID_COLOR, linecolor: PLOT_GRID_COLOR, + showticklabels: false, }, yaxis: { fixedrange: true, + range: props.yRange, gridcolor: PLOT_GRID_COLOR, linecolor: PLOT_GRID_COLOR, + showticklabels: true, + tickfont: { + size: 10, + }, + }, + title: { + multiline: true, + text: props.plotlyData.name ?? '', + font: {size: 12}, + xref: 'paper', + x: 0.5, + y: 1, + yanchor: 'top', + pad: {t: 2}, }, }; - }, [props.height]); + }, [props.height, props.plotlyData, props.yRange]); + const plotlyConfig = useMemo(() => { return { displayModeBar: false, @@ -57,11 +61,11 @@ export const PlotlyBarPlot: React.FC<{ useEffect(() => { Plotly.newPlot( divRef.current as any, - plotlyData, + [props.plotlyData], plotlyLayout, plotlyConfig ); - }, [plotlyConfig, plotlyData, plotlyLayout]); + }, [plotlyConfig, props.plotlyData, plotlyLayout]); return
; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx index d459d1354f1..47d0fa3f10c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx @@ -31,13 +31,13 @@ export const PlotlyRadarPlot: React.FC<{ }, [props.data]); const plotlyLayout: Partial = useMemo(() => { return { - height: props.height, + height: props.height - 40, showlegend: false, margin: { - l: 60, - r: 0, + l: 20, + r: 20, b: 30, - t: 30, + t: 20, pad: 0, }, polar: { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx index 5bfaa8fcb04..c23ffcce04d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx @@ -1,15 +1,6 @@ import {Box} from '@material-ui/core'; -import {Popover} from '@mui/material'; -import {Switch} from '@wandb/weave/components'; import {Button} from '@wandb/weave/components/Button'; -import { - DraggableGrow, - DraggableHandle, -} from '@wandb/weave/components/DraggablePopups'; -import {TextField} from '@wandb/weave/components/Form/TextField'; import {Tailwind} from '@wandb/weave/components/Tailwind'; -import {maybePluralize} from '@wandb/weave/core/util/string'; -import classNames from 'classnames'; import React, {useEffect, useMemo, useRef, useState} from 'react'; import {buildCompositeMetricsMap} from '../../compositeMetricsUtil'; @@ -27,6 +18,7 @@ import { resolveSummaryMetricValueForEvaluateCall, } from '../../ecpUtil'; import {HorizontalBox, VerticalBox} from '../../Layout'; +import {MetricsSelector} from './MetricsSelector'; import {PlotlyBarPlot} from './PlotlyBarPlot'; import {PlotlyRadarPlot, RadarPlotData} from './PlotlyRadarPlot'; @@ -36,15 +28,12 @@ import {PlotlyRadarPlot, RadarPlotData} from './PlotlyRadarPlot'; export const SummaryPlots: React.FC<{ state: EvaluationComparisonState; setSelectedMetrics: (newModel: Record) => void; -}> = props => { - const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics( - props.state - ); - const {selectedMetrics} = props.state; - const setSelectedMetrics = props.setSelectedMetrics; +}> = ({state, setSelectedMetrics}) => { + const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics(state); + const {selectedMetrics} = state; + // Initialize selectedMetrics if null useEffect(() => { - // If selectedMetrics is null, we should show all metrics if (selectedMetrics == null) { setSelectedMetrics( Object.fromEntries(Array.from(allMetricNames).map(m => [m, true])) @@ -52,10 +41,184 @@ export const SummaryPlots: React.FC<{ } }, [selectedMetrics, setSelectedMetrics, allMetricNames]); - // filter down the plotlyRadarData to only include the selected metrics, after - // computation, to allow quick addition/removal of metrics - const filteredPlotlyRadarData = useMemo(() => { - const filteredData: RadarPlotData = {}; + const filteredData = useFilteredData(radarData, selectedMetrics); + const normalizedRadarData = normalizeDataForRadarPlot(filteredData); + const barPlotData = useBarPlotData(filteredData); + + const { + containerRef, + isInitialRender, + plotsPerPage, + currentPage, + setCurrentPage, + } = useContainerDimensions(); + + const {plotsToShow, totalPlots, startIndex, endIndex, totalPages} = + usePaginatedPlots( + normalizedRadarData, + barPlotData, + plotsPerPage, + currentPage + ); + + // Render placeholder during initial render + if (isInitialRender) { + return
; + } + + return ( + + +
+ {plotsToShow} +
+ setCurrentPage(prev => Math.max(prev - 1, 0))} + onNextPage={() => + setCurrentPage(prev => Math.min(prev + 1, totalPages - 1)) + } + /> +
+ ); +}; + +const SectionHeader: React.FC<{ + selectedMetrics: Record | undefined; + setSelectedMetrics: (newModel: Record) => void; + allMetrics: string[]; +}> = ({selectedMetrics, setSelectedMetrics, allMetrics}) => ( + + + Summary Metrics + + +
+
Configure displayed metrics
+ +
+
+
+); + +const RadarPlotBox: React.FC<{data: RadarPlotData}> = ({data}) => ( + + + +); + +const BarPlotBox: React.FC<{ + plot: {plotlyData: Plotly.Data; yRange: [number, number]}; +}> = ({plot}) => ( + + + +); + +const PaginationControls: React.FC<{ + currentPage: number; + totalPages: number; + startIndex: number; + endIndex: number; + totalPlots: number; + onPrevPage: () => void; + onNextPage: () => void; +}> = ({ + currentPage, + totalPages, + startIndex, + endIndex, + totalPlots, + onPrevPage, + onNextPage, +}) => ( + + + +
+
+
+
+
+); + +const useFilteredData = ( + radarData: RadarPlotData, + selectedMetrics: Record | undefined +) => + useMemo(() => { + const data: RadarPlotData = {}; for (const [callId, metricBin] of Object.entries(radarData)) { const metrics: {[metric: string]: number} = {}; for (const [metric, value] of Object.entries(metricBin.metrics)) { @@ -64,253 +227,195 @@ export const SummaryPlots: React.FC<{ } } if (Object.keys(metrics).length > 0) { - filteredData[callId] = { + data[callId] = { metrics, name: metricBin.name, color: metricBin.color, }; } } - return filteredData; + return data; }, [radarData, selectedMetrics]); - return ( - - - - Summary Metrics - - -
-
Configure displayed metrics
- -
-
-
- - - - - - - - -
- ); -}; +function getMetricValuesFromRadarData(radarData: RadarPlotData): { + [metric: string]: number[]; +} { + const metricValues: {[metric: string]: number[]} = {}; + // Gather all values for each metric + Object.values(radarData).forEach(callData => { + Object.entries(callData.metrics).forEach(([metric, value]) => { + if (!metricValues[metric]) { + metricValues[metric] = []; + } + metricValues[metric].push(value); + }); + }); + return metricValues; +} -const MetricsSelector: React.FC<{ - setSelectedMetrics: (newModel: Record) => void; - selectedMetrics: Record | undefined; - allMetrics: string[]; -}> = ({setSelectedMetrics, selectedMetrics, allMetrics}) => { - const [search, setSearch] = useState(''); - - const ref = useRef(null); - const [anchorEl, setAnchorEl] = useState(null); - const onClick = (event: React.MouseEvent) => { - setAnchorEl(anchorEl ? null : ref.current); - setSearch(''); +function getMetricMinsFromRadarData(radarData: RadarPlotData): { + [metric: string]: number; +} { + const metricValues = getMetricValuesFromRadarData(radarData); + const metricMins: {[metric: string]: number} = {}; + Object.entries(metricValues).forEach(([metric, values]) => { + metricMins[metric] = Math.min(...values); + }); + return metricMins; +} + +function normalizeDataForRadarPlot(radarData: RadarPlotData): RadarPlotData { + const metricMins = getMetricMinsFromRadarData(radarData); + + const normalizedData: RadarPlotData = {}; + Object.entries(radarData).forEach(([callId, callData]) => { + normalizedData[callId] = { + name: callData.name, + color: callData.color, + metrics: {}, + }; + + Object.entries(callData.metrics).forEach(([metric, value]) => { + const min = metricMins[metric]; + // Only shift values if there are negative values + const normalizedValue = min < 0 ? value - min : value; + normalizedData[callId].metrics[metric] = normalizedValue; + }); + }); + + return normalizedData; +} + +const useBarPlotData = (filteredData: RadarPlotData) => + useMemo(() => { + const metrics: { + [metric: string]: { + callIds: string[]; + values: number[]; + name: string; + colors: string[]; + }; + } = {}; + + // Reorganize data by metric instead of by call + for (const [callId, metricBin] of Object.entries(filteredData)) { + for (const [metric, value] of Object.entries(metricBin.metrics)) { + if (!metrics[metric]) { + metrics[metric] = {callIds: [], values: [], name: metric, colors: []}; + } + metrics[metric].callIds.push(callId); + metrics[metric].values.push(value); + metrics[metric].colors.push(metricBin.color); + } + } + + // Convert metrics object to Plotly data format + return Object.entries(metrics).map(([metric, metricBin]) => { + const maxY = Math.max(...metricBin.values) * 1.1; + const minY = Math.min(...metricBin.values, 0); + const plotlyData: Plotly.Data = { + type: 'bar', + y: metricBin.values, + x: metricBin.callIds, + text: metricBin.values.map(value => value.toFixed(3)), + textposition: 'outside', + textfont: {size: 14, color: 'black'}, + name: metric, + marker: {color: metricBin.colors}, + }; + return {plotlyData, yRange: [minY, maxY] as [number, number]}; + }); + }, [filteredData]); + +const useContainerDimensions = () => { + const containerRef = useRef(null); + const [containerWidth, setContainerWidth] = useState(0); + const [isInitialRender, setIsInitialRender] = useState(true); + const [currentPage, setCurrentPage] = useState(0); + + useEffect(() => { + const updateWidth = () => { + if (containerRef.current) { + setContainerWidth(containerRef.current.offsetWidth); + } + }; + + updateWidth(); + setIsInitialRender(false); + + window.addEventListener('resize', updateWidth); + return () => window.removeEventListener('resize', updateWidth); + }, []); + + const plotsPerPage = useMemo(() => { + return Math.max(1, Math.floor(containerWidth / PLOT_HEIGHT)); + }, [containerWidth]); + + return { + containerRef, + isInitialRender, + plotsPerPage, + currentPage, + setCurrentPage, }; - const open = Boolean(anchorEl); - const id = open ? 'simple-popper' : undefined; +}; - const filteredCols = search - ? allMetrics.filter(col => col.toLowerCase().includes(search.toLowerCase())) - : allMetrics; +const usePaginatedPlots = ( + filteredData: RadarPlotData, + barPlotData: Array<{plotlyData: Plotly.Data; yRange: [number, number]}>, + plotsPerPage: number, + currentPage: number +) => { + const radarPlotWidth = 2; + const totalBarPlots = barPlotData.length; + const totalPlotWidth = radarPlotWidth + totalBarPlots; + const totalPages = Math.ceil(totalPlotWidth / plotsPerPage); - const shownMetrics = Object.values(selectedMetrics ?? {}).filter(Boolean); + const plotsToShow = useMemo(() => { + // First page always shows radar plot + if (currentPage === 0) { + const availableSpace = plotsPerPage - radarPlotWidth; + return [ + , + ...barPlotData + .slice(0, availableSpace) + .map((plot, index) => ( + + )), + ]; + } else { + // Subsequent pages show only bar plots + const startIdx = + (currentPage - 1) * plotsPerPage + (plotsPerPage - radarPlotWidth); + const endIdx = startIdx + plotsPerPage; + return barPlotData + .slice(startIdx, endIdx) + .map((plot, index) => ( + + )); + } + }, [currentPage, plotsPerPage, filteredData, barPlotData]); - const numHidden = allMetrics.length - shownMetrics.length; - const buttonSuffix = search ? `(${filteredCols.length})` : 'all'; + // Calculate pagination details + const totalPlots = barPlotData.length + 1; // +1 for the radar plot + const startIndex = + currentPage === 0 ? 1 : Math.min(plotsPerPage + 1, totalPlots); + const endIndex = + currentPage === 0 + ? Math.min(plotsToShow.length, totalPlots) + : Math.min(startIndex + plotsToShow.length - 1, totalPlots); - return ( - <> - - -
- -
-
- - - - ); + return {plotsToShow, totalPlots, startIndex, endIndex, totalPages}; }; -const normalizeValues = (values: Array): number[] => { +function normalizeValues(values: Array): number[] { // find the max value // find the power of 2 that is greater than the max value // divide all values by that power of 2 const maxVal = Math.max(...(values.filter(v => v !== undefined) as number[])); const maxPower = Math.ceil(Math.log2(maxVal)); return values.map(val => (val ? val / 2 ** maxPower : 0)); -}; +} const useNormalizedPlotDataFromMetrics = ( state: EvaluationComparisonState From a7170b4c1e6dbca9b279865b588d4edfee34468f Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:00:13 -0600 Subject: [PATCH 28/60] feat(ui): Support Anthropic calls in Chat View (#3187) --- .../Home/Browse3/pages/ChatView/hooks.ts | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts index 38b5c820195..33ced58ec49 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts @@ -246,6 +246,63 @@ export const isTraceCallChatFormatGemini = (call: TraceCallSchema): boolean => { ); }; +export const isAnthropicContentBlock = (block: any): boolean => { + if (!_.isPlainObject(block)) { + return false; + } + // TODO: Are there other types? + if (block.type !== 'text') { + return false; + } + if (!hasStringProp(block, 'text')) { + return false; + } + return true; +}; + +export const isAnthropicCompletionFormat = (output: any): boolean => { + if (output !== null) { + // TODO: Could have additional checks here on things like usage + if ( + _.isPlainObject(output) && + output.type === 'message' && + output.role === 'assistant' && + hasStringProp(output, 'model') && + _.isArray(output.content) && + output.content.every((c: any) => isAnthropicContentBlock(c)) + ) { + return true; + } + return false; + } + return true; +}; + +type AnthropicContentBlock = { + type: 'text'; + text: string; +}; + +export const anthropicContentBlocksToChoices = ( + blocks: AnthropicContentBlock[], + stopReason: string +): Choice[] => { + const choices: Choice[] = []; + for (let i = 0; i < blocks.length; i++) { + const block = blocks[i]; + choices.push({ + index: i, + message: { + role: 'assistant', + content: block.text, + }, + // TODO: What is correct way to map this? + finish_reason: stopReason, + }); + } + return choices; +}; + export const isTraceCallChatFormatOpenAI = (call: TraceCallSchema): boolean => { if (!('messages' in call.inputs)) { return false; @@ -336,6 +393,19 @@ export const normalizeChatRequest = (request: any): ChatRequest => { ], }; } + // Anthropic has system message as a top-level request field + if (hasStringProp(request, 'system')) { + return { + ...request, + messages: [ + { + role: 'system', + content: request.system, + }, + ...request.messages, + ], + }; + } return request as ChatRequest; }; @@ -360,6 +430,24 @@ export const normalizeChatCompletion = ( }, }; } + if (isAnthropicCompletionFormat(completion)) { + return { + id: completion.id, + choices: anthropicContentBlocksToChoices( + completion.content, + completion.stop_reason + ), + created: 0, + model: completion.model, + system_fingerprint: '', + usage: { + prompt_tokens: completion.usage.input_tokens, + completion_tokens: completion.usage.output_tokens, + total_tokens: + completion.usage.input_tokens + completion.usage.output_tokens, + }, + }; + } return completion as ChatCompletion; }; From 935f68cbfdd55a2f65a476b201d2227167416097 Mon Sep 17 00:00:00 2001 From: Connie Lee Date: Tue, 10 Dec 2024 09:28:08 -0800 Subject: [PATCH 29/60] feat(ui): Add ExpandingPill component (#3184) --- weave-js/src/components/Tag/Pill.tsx | 39 ++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/weave-js/src/components/Tag/Pill.tsx b/weave-js/src/components/Tag/Pill.tsx index 6f734f9cbfc..ed958a0c543 100644 --- a/weave-js/src/components/Tag/Pill.tsx +++ b/weave-js/src/components/Tag/Pill.tsx @@ -59,3 +59,42 @@ export const IconOnlyPill: FC = ({ ); }; + +export type ExpandingPillProps = { + className?: string; + color?: TagColorName; + icon: IconName; + label: string; +}; +export const ExpandingPill = ({ + className, + color, + icon, + label, +}: ExpandingPillProps) => { + const classes = useTagClasses({color, isInteractive: true}); + return ( + +
+ + + {label} + +
+
+ ); +}; From a14edbe6f6c37f3e9094b50ffdd2adc5bd3ecb65 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:52:04 -0600 Subject: [PATCH 30/60] chore(weave): ui_url error gives a hint about the problem (#3190) --- weave/trace/weave_client.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 86ba7b8653e..0eca3fcbedb 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -239,7 +239,9 @@ def func_name(self) -> str: @property def feedback(self) -> RefFeedbackQuery: if not self.id: - raise ValueError("Can't get feedback for call without ID") + raise ValueError( + "Can't get feedback for call without ID, was `weave.init` called?" + ) if self._feedback is None: try: @@ -253,7 +255,9 @@ def feedback(self) -> RefFeedbackQuery: @property def ui_url(self) -> str: if not self.id: - raise ValueError("Can't get URL for call without ID") + raise ValueError( + "Can't get URL for call without ID, was `weave.init` called?" + ) try: entity, project = self.project_id.split("/") @@ -265,7 +269,9 @@ def ui_url(self) -> str: def ref(self) -> CallRef: entity, project = self.project_id.split("/") if not self.id: - raise ValueError("Can't get ref for call without ID") + raise ValueError( + "Can't get ref for call without ID, was `weave.init` called?" + ) return CallRef(entity, project, self.id) @@ -273,7 +279,9 @@ def ref(self) -> CallRef: def children(self) -> CallsIter: client = weave_client_context.require_weave_client() if not self.id: - raise ValueError("Can't get children of call without ID") + raise ValueError( + "Can't get children of call without ID, was `weave.init` called?" + ) client = weave_client_context.require_weave_client() return CallsIter( From c19fe6bf246c058a3cbaa11bd544eb2e2fbbd24c Mon Sep 17 00:00:00 2001 From: brianlund-wandb Date: Tue, 10 Dec 2024 10:36:51 -0800 Subject: [PATCH 31/60] fix(app): render bounding boxes on media panel without points (#3158) * WB-16623 coercing empty point to default, targeting camera on bounding box when point cloud empty * code style cleanup --- .../src/common/util/SdkPointCloudToBabylon.test.ts | 14 ++++++++++++++ weave-js/src/common/util/SdkPointCloudToBabylon.ts | 2 +- weave-js/src/common/util/render_babylon.ts | 13 +++++++++++-- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts b/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts index 95c7639dc4a..df5eca57b46 100644 --- a/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts +++ b/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts @@ -4,6 +4,7 @@ import { DEFAULT_POINT_COLOR, getFilteringOptionsForPointCloud, getVertexCompatiblePositionsAndColors, + loadPointCloud, MAX_BOUNDING_BOX_LABELS_FOR_DISPLAY, MaxAlphaValue, } from './SdkPointCloudToBabylon'; @@ -174,3 +175,16 @@ describe('getFilteringOptionsForPointCloud', () => { expect(newClassIdToLabel.get(49)).toEqual('label49'); }); }); +describe('loadPointCloud', () => { + it('appropriate defaults set when loading point cloud from file', () => { + const fileContents = JSON.stringify({ + boxes: [], + points: [[]], + type: 'lidar/beta', + vectors: [], + }); + const babylonPointCloud = loadPointCloud(fileContents); + expect(babylonPointCloud.points).toHaveLength(1); + expect(babylonPointCloud.points[0].position).toEqual([0, 0, 0]); + }); +}); diff --git a/weave-js/src/common/util/SdkPointCloudToBabylon.ts b/weave-js/src/common/util/SdkPointCloudToBabylon.ts index 274e1676be4..d52682743ee 100644 --- a/weave-js/src/common/util/SdkPointCloudToBabylon.ts +++ b/weave-js/src/common/util/SdkPointCloudToBabylon.ts @@ -160,7 +160,7 @@ export const handlePoints = (object3D: Object3DScene): ScenePoint[] => { // Draw Points return truncatedPoints.map(point => { const [x, y, z, r, g, b] = point; - const position: Position = [x, y, z]; + const position: Position = [x ?? 0, y ?? 0, z ?? 0]; const category = r; if (r !== undefined && g !== undefined && b !== undefined) { diff --git a/weave-js/src/common/util/render_babylon.ts b/weave-js/src/common/util/render_babylon.ts index 10aee3f6c51..ebd213c2677 100644 --- a/weave-js/src/common/util/render_babylon.ts +++ b/weave-js/src/common/util/render_babylon.ts @@ -394,6 +394,15 @@ const pointCloudScene = ( // Apply vertexData to custom mesh vertexData.applyToMesh(pcMesh); + // A file without any points defined still includes a single, empty "point". + // In order to play nice with Babylon, we position this empty point at 0,0,0. + // Hence, a pointCloud with a single point at 0,0,0 is likely empty. + const isEmpty = + pointCloud.points.length === 1 && + pointCloud.points[0].position[0] === 0 && + pointCloud.points[0].position[1] === 0 && + pointCloud.points[0].position[2] === 0; + camera.parent = pcMesh; const pcMaterial = new Babylon.StandardMaterial('mat', scene); @@ -472,8 +481,8 @@ const pointCloudScene = ( new Vector3(edges.length * 2, edges.length * 2, edges.length * 2) ); - // If we are iterating over camera, target a box - if (index === meta?.cameraIndex) { + // If we are iterating over camera or the cloud is empty, target a box + if (index === meta?.cameraIndex || (index === 0 && isEmpty)) { camera.position = center.add(new Vector3(0, 0, 1000)); camera.target = center; camera.zoomOn([lines]); From 0966c9fb21d00fa6b3ed9019de2c4c471fc6ec30 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Tue, 10 Dec 2024 12:08:50 -0800 Subject: [PATCH 32/60] fix(ui): feedback header when empty (#3191) --- .../Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx index ef6bcbd69ff..0b3c9603fef 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx @@ -96,7 +96,7 @@ export const FeedbackSidebar = ({
Feedback
-
+
{humanAnnotationSpecs.length > 0 ? ( <>
From 35a78d6e83006ea4462d48485651213a4afbd381 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Tue, 10 Dec 2024 16:20:17 -0800 Subject: [PATCH 33/60] chore(ui): make eval section header draggable (#2637) --- .../CompareEvaluationsPage.tsx | 14 +-- .../compareEvaluationsContext.tsx | 25 ++-- .../pages/CompareEvaluationsPage/ecpState.ts | 63 +++++++--- .../pages/CompareEvaluationsPage/ecpTypes.ts | 1 + .../ComparisonDefinitionSection.tsx | 110 ++++++++++++------ .../EvaluationDefinition.tsx | 83 +------------ .../DraggableSection/DraggableItem.tsx | 103 ++++++++++++++++ .../DraggableSection/DraggableSection.tsx | 34 ++++++ .../ExampleCompareSection.tsx | 2 +- .../ExampleFilterSection.tsx | 4 +- .../ScorecardSection/ScorecardSection.tsx | 3 +- 11 files changed, 280 insertions(+), 162 deletions(-) create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx index b302b0262ae..478c4887546 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx @@ -77,8 +77,6 @@ export const CompareEvaluationsPage: React.FC< export const CompareEvaluationsPageContent: React.FC< CompareEvaluationsPageProps > = props => { - const [baselineEvaluationCallId, setBaselineEvaluationCallId] = - React.useState(null); const [comparisonDimensions, setComparisonDimensions] = React.useState(null); @@ -104,14 +102,6 @@ export const CompareEvaluationsPageContent: React.FC< [comparisonDimensions] ); - React.useEffect(() => { - // Only update the baseline if we are switching evaluations, if there - // is more than 1, we are in the compare view and baseline is auto set - if (props.evaluationCallIds.length === 1) { - setBaselineEvaluationCallId(props.evaluationCallIds[0]); - } - }, [props.evaluationCallIds]); - if (props.evaluationCallIds.length === 0) { return
No evaluations to compare
; } @@ -120,13 +110,11 @@ export const CompareEvaluationsPageContent: React.FC< diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx index b65658a890a..638565e8ad6 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx @@ -10,9 +10,6 @@ import {ComparisonDimensionsType} from './ecpState'; const CompareEvaluationsContext = React.createContext<{ state: EvaluationComparisonState; - setBaselineEvaluationCallId: React.Dispatch< - React.SetStateAction - >; setComparisonDimensions: React.Dispatch< React.SetStateAction >; @@ -20,6 +17,7 @@ const CompareEvaluationsContext = React.createContext<{ setSelectedMetrics: (newModel: Record) => void; addEvaluationCall: (newCallId: string) => void; removeEvaluationCall: (callId: string) => void; + setEvaluationCallOrder: (newCallIdOrder: string[]) => void; } | null>(null); export const useCompareEvaluationsState = () => { @@ -33,34 +31,26 @@ export const useCompareEvaluationsState = () => { export const CompareEvaluationsProvider: React.FC<{ entity: string; project: string; + initialEvaluationCallIds: string[]; selectedMetrics: Record | null; setSelectedMetrics: (newModel: Record) => void; - initialEvaluationCallIds: string[]; onEvaluationCallIdsUpdate: (newEvaluationCallIds: string[]) => void; - setBaselineEvaluationCallId: React.Dispatch< - React.SetStateAction - >; setComparisonDimensions: React.Dispatch< React.SetStateAction >; setSelectedInputDigest: React.Dispatch>; - baselineEvaluationCallId?: string; comparisonDimensions?: ComparisonDimensionsType; selectedInputDigest?: string; }> = ({ entity, project, + initialEvaluationCallIds, selectedMetrics, setSelectedMetrics, - - initialEvaluationCallIds, onEvaluationCallIdsUpdate, - setBaselineEvaluationCallId, setComparisonDimensions, - setSelectedInputDigest, - baselineEvaluationCallId, comparisonDimensions, selectedInputDigest, children, @@ -77,7 +67,6 @@ export const CompareEvaluationsProvider: React.FC<{ entity, project, evaluationCallIds, - baselineEvaluationCallId, comparisonDimensions, selectedInputDigest, selectedMetrics ?? undefined @@ -89,7 +78,6 @@ export const CompareEvaluationsProvider: React.FC<{ } return { state: initialState.result, - setBaselineEvaluationCallId, setComparisonDimensions, setSelectedInputDigest, setSelectedMetrics, @@ -105,14 +93,17 @@ export const CompareEvaluationsProvider: React.FC<{ setEvaluationCallIds(newEvaluationCallIds); onEvaluationCallIdsUpdate(newEvaluationCallIds); }, + setEvaluationCallOrder: (newCallIdOrder: string[]) => { + setEvaluationCallIds(newCallIdOrder); + onEvaluationCallIdsUpdate(newCallIdOrder); + }, }; }, [ initialState.loading, initialState.result, + setEvaluationCallIds, evaluationCallIds, onEvaluationCallIdsUpdate, - setEvaluationCallIds, - setBaselineEvaluationCallId, setComparisonDimensions, setSelectedInputDigest, setSelectedMetrics, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts index 35f95dbf14f..e5c1b03d60a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts @@ -18,14 +18,14 @@ import {getMetricIds} from './ecpUtil'; export type EvaluationComparisonState = { // The normalized data for the evaluations data: EvaluationComparisonData; - // The evaluation call id of the baseline model - baselineEvaluationCallId: string; // The dimensions to compare & filter results comparisonDimensions?: ComparisonDimensionsType; // The current digest which is in view selectedInputDigest?: string; // The selected metrics to display selectedMetrics?: Record; + // Ordered call Ids + evaluationCallIdsOrdered: string[]; }; export type ComparisonDimensionsType = Array<{ @@ -43,12 +43,14 @@ export const useEvaluationComparisonState = ( entity: string, project: string, evaluationCallIds: string[], - baselineEvaluationCallId?: string, comparisonDimensions?: ComparisonDimensionsType, selectedInputDigest?: string, selectedMetrics?: Record ): Loadable => { - const data = useEvaluationComparisonData(entity, project, evaluationCallIds); + const orderedCallIds = useMemo(() => { + return getCallIdsOrderedForQuery(evaluationCallIds); + }, [evaluationCallIds]); + const data = useEvaluationComparisonData(entity, project, orderedCallIds); const value = useMemo(() => { if (data.result == null || data.loading) { @@ -92,42 +94,45 @@ export const useEvaluationComparisonState = ( loading: false, result: { data: data.result, - baselineEvaluationCallId: - baselineEvaluationCallId ?? evaluationCallIds[0], comparisonDimensions: newComparisonDimensions, selectedInputDigest, selectedMetrics, + evaluationCallIdsOrdered: evaluationCallIds, }, }; }, [ data.result, data.loading, - baselineEvaluationCallId, - evaluationCallIds, comparisonDimensions, selectedInputDigest, selectedMetrics, + evaluationCallIds, ]); return value; }; +export const getOrderedCallIds = (state: EvaluationComparisonState) => { + return Array.from(state.evaluationCallIdsOrdered); +}; + +export const getBaselineCallId = (state: EvaluationComparisonState) => { + return getOrderedCallIds(state)[0]; +}; + /** - * Should use this over keys of `state.data.evaluationCalls` because it ensures the baseline - * evaluation call is first. + * Sort call IDs to ensure consistent order for memoized query params */ -export const getOrderedCallIds = (state: EvaluationComparisonState) => { - const initial = Object.keys(state.data.evaluationCalls); - moveItemToFront(initial, state.baselineEvaluationCallId); - return initial; +const getCallIdsOrderedForQuery = (callIds: string[]) => { + return Array.from(callIds).sort(); }; /** * Should use this over keys of `state.data.models` because it ensures the baseline model is first. */ export const getOrderedModelRefs = (state: EvaluationComparisonState) => { - const baselineRef = - state.data.evaluationCalls[state.baselineEvaluationCallId].modelRef; + const baselineCallId = getBaselineCallId(state); + const baselineRef = state.data.evaluationCalls[baselineCallId].modelRef; const refs = Object.keys(state.data.models); // Make sure the baseline model is first moveItemToFront(refs, baselineRef); @@ -145,3 +150,29 @@ const moveItemToFront = (arr: T[], item: T): T[] => { } return arr; }; + +export const getOrderedEvalsWithNewBaseline = ( + evaluationCallIds: string[], + newBaselineCallId: string +) => { + return moveItemToFront(evaluationCallIds, newBaselineCallId); +}; + +export const swapEvaluationCalls = ( + evaluationCallIds: string[], + ndx1: number, + ndx2: number +): string[] => { + return swapArrayItems(evaluationCallIds, ndx1, ndx2); +}; + +const swapArrayItems = (arr: T[], ndx1: number, ndx2: number): T[] => { + if (ndx1 < 0 || ndx2 < 0 || ndx1 >= arr.length || ndx2 >= arr.length) { + throw new Error('Index out of bounds'); + } + const newArr = [...arr]; + const from = newArr[ndx1]; + newArr[ndx1] = newArr[ndx2]; + newArr[ndx2] = from; + return newArr; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts index 7454e0707b4..b4642fae240 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts @@ -18,6 +18,7 @@ export type EvaluationComparisonData = { }; // EvaluationCalls are the specific calls of an evaluation. + // The visual order of the evaluation calls is determined by the order of the keys. evaluationCalls: { [callId: string]: EvaluationCall; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx index 66ce56c4cb0..2704a66cbea 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx @@ -12,44 +12,81 @@ import { import {useCallsForQuery} from '../../../CallsPage/callsTableQuery'; import {useEvaluationsFilter} from '../../../CallsPage/evaluationsFilter'; import {Id} from '../../../common/Id'; +import {opNiceName} from '../../../common/Links'; import {useWFHooks} from '../../../wfReactInterface/context'; import { CallSchema, ObjectVersionKey, } from '../../../wfReactInterface/wfDataModelHooksInterface'; import {useCompareEvaluationsState} from '../../compareEvaluationsContext'; -import {STANDARD_PADDING} from '../../ecpConstants'; -import {getOrderedCallIds} from '../../ecpState'; -import {EvaluationComparisonState} from '../../ecpState'; +import { + EvaluationComparisonState, + getOrderedCallIds, + getOrderedEvalsWithNewBaseline, + swapEvaluationCalls, +} from '../../ecpState'; import {HorizontalBox} from '../../Layout'; -import {EvaluationDefinition, VerticalBar} from './EvaluationDefinition'; +import {ItemDef} from '../DraggableSection/DraggableItem'; +import {DraggableSection} from '../DraggableSection/DraggableSection'; +import {VerticalBar} from './EvaluationDefinition'; export const ComparisonDefinitionSection: React.FC<{ state: EvaluationComparisonState; }> = props => { - const evalCallIds = useMemo( - () => getOrderedCallIds(props.state), - [props.state] - ); + const {setEvaluationCallOrder, removeEvaluationCall} = + useCompareEvaluationsState(); + + const callIds = useMemo(() => { + return getOrderedCallIds(props.state); + }, [props.state]); + + const items: ItemDef[] = useMemo(() => { + return callIds.map(callId => ({ + key: 'evaluations', + value: callId, + label: props.state.data.evaluationCalls[callId]?.name ?? callId, + })); + }, [callIds, props.state.data.evaluationCalls]); + + const onSetBaseline = (value: string | null) => { + if (!value) { + return; + } + const newSortOrder = getOrderedEvalsWithNewBaseline(callIds, value); + setEvaluationCallOrder(newSortOrder); + }; + const onRemoveItem = (value: string) => removeEvaluationCall(value); + const onSortEnd = ({ + oldIndex, + newIndex, + }: { + oldIndex: number; + newIndex: number; + }) => { + const newSortOrder = swapEvaluationCalls(callIds, oldIndex, newIndex); + setEvaluationCallOrder(newSortOrder); + }; return ( - - {evalCallIds.map((key, ndx) => { - return ( - - - - ); - })} - - + +
+ + + + + + +
+
); }; @@ -81,7 +118,7 @@ const ModelRefLabel: React.FC<{modelRef: string}> = props => { const objectVersion = useObjectVersion(objVersionKey); return ( - {objectVersion.result?.objectId}:{objectVersion.result?.versionIndex} + {objectVersion.result?.objectId}:v{objectVersion.result?.versionIndex} ); }; @@ -119,10 +156,9 @@ const AddEvaluationButton: React.FC<{ const evalsNotComparing = useMemo(() => { return calls.result.filter( - call => - !Object.keys(props.state.data.evaluationCalls).includes(call.callId) + call => !getOrderedCallIds(props.state).includes(call.callId) ); - }, [calls.result, props.state.data.evaluationCalls]); + }, [calls.result, props.state]); const [menuOptions, setMenuOptions] = useState(evalsNotComparing); @@ -222,12 +258,18 @@ const AddEvaluationButton: React.FC<{ variant="ghost" size="small" className="pb-8 pt-8 font-['Source_Sans_Pro'] text-base font-normal text-moon-800" - onClick={() => { - addEvaluationCall(call.callId); - }}> + onClick={() => addEvaluationCall(call.callId)}> <> - {call.displayName ?? call.spanName} - + + {call.displayName ?? opNiceName(call.spanName)} + + + + diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx index adc80d044f6..5dcf835e378 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx @@ -1,8 +1,5 @@ import {Box} from '@material-ui/core'; import {Circle} from '@mui/icons-material'; -import {PopupDropdown} from '@wandb/weave/common/components/PopupDropdown'; -import {Button} from '@wandb/weave/components/Button'; -import {Pill} from '@wandb/weave/components/Tag'; import React, {useMemo} from 'react'; import { @@ -17,87 +14,17 @@ import {SmallRef} from '../../../../../Browse2/SmallRef'; import {CallLink, ObjectVersionLink} from '../../../common/Links'; import {useWFHooks} from '../../../wfReactInterface/context'; import {ObjectVersionKey} from '../../../wfReactInterface/wfDataModelHooksInterface'; -import {useCompareEvaluationsState} from '../../compareEvaluationsContext'; -import { - BOX_RADIUS, - CIRCLE_SIZE, - EVAL_DEF_HEIGHT, - STANDARD_BORDER, -} from '../../ecpConstants'; +import {CIRCLE_SIZE} from '../../ecpConstants'; import {EvaluationComparisonState} from '../../ecpState'; -import {HorizontalBox} from '../../Layout'; - -export const EvaluationDefinition: React.FC<{ - state: EvaluationComparisonState; - callId: string; -}> = props => { - const {removeEvaluationCall, setBaselineEvaluationCallId} = - useCompareEvaluationsState(); - - const menuOptions = useMemo(() => { - return [ - { - key: 'add-to-baseline', - content: 'Set as baseline', - onClick: () => { - setBaselineEvaluationCallId(props.callId); - }, - disabled: props.callId === props.state.baselineEvaluationCallId, - }, - { - key: 'remove', - content: 'Remove', - onClick: () => { - removeEvaluationCall(props.callId); - }, - disabled: Object.keys(props.state.data.evaluationCalls).length === 1, - }, - ]; - }, [ - props.callId, - props.state.baselineEvaluationCallId, - props.state.data.evaluationCalls, - removeEvaluationCall, - setBaselineEvaluationCallId, - ]); - - return ( - - - {props.callId === props.state.baselineEvaluationCallId && ( - - )} -
- - } - /> -
-
- ); -}; export const EvaluationCallLink: React.FC<{ callId: string; state: EvaluationComparisonState; }> = props => { - const evaluationCall = props.state.data.evaluationCalls[props.callId]; + const evaluationCall = props.state.data.evaluationCalls?.[props.callId]; + if (!evaluationCall) { + return null; + } const {entity, project} = props.state.data; return ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx new file mode 100644 index 00000000000..1510c502b99 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx @@ -0,0 +1,103 @@ +import {Button} from '@wandb/weave/components/Button'; +import * as DropdownMenu from '@wandb/weave/components/DropdownMenu'; +import {Icon} from '@wandb/weave/components/Icon'; +import {Pill} from '@wandb/weave/components/Tag/Pill'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import classNames from 'classnames'; +import React, {useState} from 'react'; +import {SortableElement, SortableHandle} from 'react-sortable-hoc'; + +import {EvaluationComparisonState} from '../../ecpState'; +import {EvaluationCallLink} from '../ComparisonDefinitionSection/EvaluationDefinition'; + +export type ItemDef = { + key: string; + value: string; + label?: string; +}; + +type DraggableItemProps = { + state: EvaluationComparisonState; + item: ItemDef; + numItems: number; + idx: number; + onRemoveItem: (value: string) => void; + onSetBaseline: (value: string | null) => void; +}; + +export const DraggableItem = SortableElement( + ({ + state, + item, + numItems, + idx, + onRemoveItem, + onSetBaseline, + }: DraggableItemProps) => { + const isDeletable = numItems > 1; + const isBaseline = idx === 0; + const [isOpen, setIsOpen] = useState(false); + + const onMakeBaselinePropagated = (e: React.MouseEvent) => { + e.stopPropagation(); + onSetBaseline(item.value); + }; + + const onRemoveItemPropagated = (e: React.MouseEvent) => { + e.stopPropagation(); + onRemoveItem(item.value); + }; + + return ( + +
+ +
+ + {isBaseline && ( + + )} +
+ + +
+
+ ); + } +); + +const DragHandle = SortableHandle(() => ( +
+ +
+)); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx new file mode 100644 index 00000000000..23a03ceb5b3 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx @@ -0,0 +1,34 @@ +import React from 'react'; +import {SortableContainer} from 'react-sortable-hoc'; + +import {EvaluationComparisonState} from '../../ecpState'; +import {DraggableItem} from './DraggableItem'; +import {ItemDef} from './DraggableItem'; + +type DraggableSectionProps = { + state: EvaluationComparisonState; + items: ItemDef[]; + onSetBaseline: (value: string | null) => void; + onRemoveItem: (value: string) => void; +}; + +export const DraggableSection = SortableContainer( + ({state, items, onSetBaseline, onRemoveItem}: DraggableSectionProps) => { + return ( +
+ {items.map((item, index) => ( + + ))} +
+ ); + } +); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx index 6041492b5c5..398f65ecd45 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx @@ -149,7 +149,7 @@ const stickySidebarHeaderMixin: React.CSSProperties = { /** * This component will occupy the entire space provided by the parent container. - * It is intended to be used in teh CompareEvaluations page, as it depends on + * It is intended to be used in the CompareEvaluations page, as it depends on * the EvaluationComparisonState. However, in principle, it is a general purpose * model-output comparison tool. It allows the user to view inputs, then compare * model outputs and evaluation metrics across multiple trials. diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx index b2c8773a7d1..1146c8ea960 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx @@ -13,7 +13,7 @@ import { } from '../../compositeMetricsUtil'; import {PLOT_HEIGHT, STANDARD_PADDING} from '../../ecpConstants'; import {MAX_PLOT_DOT_SIZE, MIN_PLOT_DOT_SIZE} from '../../ecpConstants'; -import {EvaluationComparisonState} from '../../ecpState'; +import {EvaluationComparisonState, getBaselineCallId} from '../../ecpState'; import {metricDefinitionId} from '../../ecpUtil'; import { flattenedDimensionPath, @@ -103,7 +103,7 @@ const SingleDimensionFilter: React.FC<{ }, [props.state.data]); const {setComparisonDimensions} = useCompareEvaluationsState(); - const baselineCallId = props.state.baselineEvaluationCallId; + const baselineCallId = getBaselineCallId(props.state); const compareCallId = Object.keys(props.state.data.evaluationCalls).find( callId => callId !== baselineCallId )!; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx index 4b319150fa8..303f640f47e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx @@ -32,6 +32,7 @@ import { } from '../../ecpConstants'; import { EvaluationComparisonState, + getBaselineCallId, getOrderedCallIds, getOrderedModelRefs, } from '../../ecpState'; @@ -414,7 +415,7 @@ export const ScorecardSection: React.FC<{ {evalCallIds.map((evalCallId, mNdx) => { const baseline = resolveSummaryMetricResult( - props.state.baselineEvaluationCallId, + getBaselineCallId(props.state), groupName, metricKey, compositeSummaryMetrics, From f967984c8bf49bf49e1e5a500b338381f2481e4e Mon Sep 17 00:00:00 2001 From: Erica Diaz <156136421+ericakdiaz@users.noreply.github.com> Date: Tue, 10 Dec 2024 16:23:39 -0800 Subject: [PATCH 34/60] chore(weave): Update ToggleButtonGroup to allow disabling individual options (#3194) --- weave-js/src/components/ToggleButtonGroup.tsx | 67 +++++++++++-------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/weave-js/src/components/ToggleButtonGroup.tsx b/weave-js/src/components/ToggleButtonGroup.tsx index 65b2c538975..eca93e95657 100644 --- a/weave-js/src/components/ToggleButtonGroup.tsx +++ b/weave-js/src/components/ToggleButtonGroup.tsx @@ -9,6 +9,7 @@ import {Tailwind} from './Tailwind'; export type ToggleOption = { value: string; icon?: IconName; + isDisabled?: boolean; }; export type ToggleButtonGroupProps = { @@ -37,7 +38,10 @@ export const ToggleButtonGroup = React.forwardRef< } const handleValueChange = (newValue: string) => { - if (newValue !== value) { + if ( + newValue !== value && + options.find(option => option.value === newValue)?.isDisabled !== true + ) { onValueChange(newValue); } }; @@ -49,34 +53,39 @@ export const ToggleButtonGroup = React.forwardRef< onValueChange={handleValueChange} className="flex gap-px" ref={ref}> - {options.map(({value: optionValue, icon}) => ( - - - - ))} + {options.map( + ({value: optionValue, icon, isDisabled: optionIsDisabled}) => ( + + + + ) + )} ); From 4e1e2c9a6da47ceff2fbd113053d959b43b4ce6f Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 10 Dec 2024 16:30:13 -0800 Subject: [PATCH 35/60] chore(weave): Fix predict detection heuristic for evals (#3196) * init * lint --- .../tsDataModelHooksEvaluationComparison.ts | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts index 8418f2976ad..847421464e3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts @@ -69,7 +69,6 @@ * across different datasets. */ -import _ from 'lodash'; import {sum} from 'lodash'; import {useEffect, useMemo, useRef, useState} from 'react'; @@ -479,17 +478,8 @@ const fetchEvaluationComparisonData = async ( const maybeDigest = parts[1]; if (maybeDigest != null && !maybeDigest.includes('/')) { const rowDigest = maybeDigest; - const possiblePredictNames = [ - 'predict', - 'infer', - 'forward', - 'invoke', - ]; const isProbablyPredictCall = - (_.some(possiblePredictNames, name => - traceCall.op_name.includes(`.${name}:`) - ) && - modelRefs.includes(traceCall.inputs.self)) || + modelRefs.includes(traceCall.inputs.self) || modelRefs.includes(traceCall.op_name); const isProbablyScoreCall = scorerRefs.has(traceCall.op_name); From bf39669bf3ebba977e9b6cd9224c9abce22a3a62 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Tue, 10 Dec 2024 16:36:47 -0800 Subject: [PATCH 36/60] fix(ui): Evaluations -- normalize radar data, show real values for bar charts (#3199) --- .../SummaryPlotsSection.tsx | 94 ++++++++++--------- 1 file changed, 50 insertions(+), 44 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx index c23ffcce04d..02c456df850 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx @@ -29,7 +29,7 @@ export const SummaryPlots: React.FC<{ state: EvaluationComparisonState; setSelectedMetrics: (newModel: Record) => void; }> = ({state, setSelectedMetrics}) => { - const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics(state); + const {radarData, allMetricNames} = usePlotDataFromMetrics(state); const {selectedMetrics} = state; // Initialize selectedMetrics if null @@ -237,11 +237,10 @@ const useFilteredData = ( return data; }, [radarData, selectedMetrics]); -function getMetricValuesFromRadarData(radarData: RadarPlotData): { +function getMetricValuesMap(radarData: RadarPlotData): { [metric: string]: number[]; } { const metricValues: {[metric: string]: number[]} = {}; - // Gather all values for each metric Object.values(radarData).forEach(callData => { Object.entries(callData.metrics).forEach(([metric, value]) => { if (!metricValues[metric]) { @@ -253,37 +252,54 @@ function getMetricValuesFromRadarData(radarData: RadarPlotData): { return metricValues; } -function getMetricMinsFromRadarData(radarData: RadarPlotData): { - [metric: string]: number; +function normalizeMetricValues(values: number[]): { + normalizedValues: number[]; + normalizer: number; } { - const metricValues = getMetricValuesFromRadarData(radarData); - const metricMins: {[metric: string]: number} = {}; - Object.entries(metricValues).forEach(([metric, values]) => { - metricMins[metric] = Math.min(...values); - }); - return metricMins; + const min = Math.min(...values); + const max = Math.max(...values); + + if (min === max) { + return { + normalizedValues: values.map(() => 0.5), + normalizer: 1, + }; + } + + // Handle negative values by shifting + const shiftedValues = min < 0 ? values.map(v => v - min) : values; + const maxValue = min < 0 ? max - min : max; + + const maxPower = Math.ceil(Math.log2(maxValue)); + const normalizer = Math.pow(2, maxPower); + + return { + normalizedValues: shiftedValues.map(v => v / normalizer), + normalizer, + }; } -function normalizeDataForRadarPlot(radarData: RadarPlotData): RadarPlotData { - const metricMins = getMetricMinsFromRadarData(radarData); +function normalizeDataForRadarPlot( + radarDataOriginal: RadarPlotData +): RadarPlotData { + const radarData = Object.fromEntries( + Object.entries(radarDataOriginal).map(([callId, callData]) => [ + callId, + {...callData, metrics: {...callData.metrics}}, + ]) + ); - const normalizedData: RadarPlotData = {}; - Object.entries(radarData).forEach(([callId, callData]) => { - normalizedData[callId] = { - name: callData.name, - color: callData.color, - metrics: {}, - }; + const metricValues = getMetricValuesMap(radarData); - Object.entries(callData.metrics).forEach(([metric, value]) => { - const min = metricMins[metric]; - // Only shift values if there are negative values - const normalizedValue = min < 0 ? value - min : value; - normalizedData[callId].metrics[metric] = normalizedValue; + // Normalize each metric independently + Object.entries(metricValues).forEach(([metric, values]) => { + const {normalizedValues} = normalizeMetricValues(values); + Object.values(radarData).forEach((callData, index) => { + callData.metrics[metric] = normalizedValues[index]; }); }); - return normalizedData; + return radarData; } const useBarPlotData = (filteredData: RadarPlotData) => @@ -317,7 +333,9 @@ const useBarPlotData = (filteredData: RadarPlotData) => type: 'bar', y: metricBin.values, x: metricBin.callIds, - text: metricBin.values.map(value => value.toFixed(3)), + text: metricBin.values.map(value => + Number.isInteger(value) ? value.toString() : value.toFixed(3) + ), textposition: 'outside', textfont: {size: 14, color: 'black'}, name: metric, @@ -408,16 +426,7 @@ const usePaginatedPlots = ( return {plotsToShow, totalPlots, startIndex, endIndex, totalPages}; }; -function normalizeValues(values: Array): number[] { - // find the max value - // find the power of 2 that is greater than the max value - // divide all values by that power of 2 - const maxVal = Math.max(...(values.filter(v => v !== undefined) as number[])); - const maxPower = Math.ceil(Math.log2(maxVal)); - return values.map(val => (val ? val / 2 ** maxPower : 0)); -} - -const useNormalizedPlotDataFromMetrics = ( +const usePlotDataFromMetrics = ( state: EvaluationComparisonState ): {radarData: RadarPlotData; allMetricNames: Set} => { const compositeMetrics = useMemo(() => { @@ -428,7 +437,7 @@ const useNormalizedPlotDataFromMetrics = ( }, [state]); return useMemo(() => { - const normalizedMetrics = Object.values(compositeMetrics) + const metrics = Object.values(compositeMetrics) .map(scoreGroup => Object.values(scoreGroup.metrics)) .flat() .map(metric => { @@ -449,11 +458,8 @@ const useNormalizedPlotDataFromMetrics = ( return val; } }); - const normalizedValues = normalizeValues(values); const evalScores: {[evalCallId: string]: number | undefined} = - Object.fromEntries( - callIds.map((key, i) => [key, normalizedValues[i]]) - ); + Object.fromEntries(callIds.map((key, i) => [key, values[i]])); const metricLabel = flattenedDimensionPath( Object.values(metric.scorerRefs)[0].metric @@ -472,7 +478,7 @@ const useNormalizedPlotDataFromMetrics = ( name: evalCall.name, color: evalCall.color, metrics: Object.fromEntries( - normalizedMetrics.map(metric => { + metrics.map(metric => { return [ metric.metricLabel, metric.evalScores[evalCall.callId] ?? 0, @@ -483,7 +489,7 @@ const useNormalizedPlotDataFromMetrics = ( ]; }) ); - const allMetricNames = new Set(normalizedMetrics.map(m => m.metricLabel)); + const allMetricNames = new Set(metrics.map(m => m.metricLabel)); return {radarData, allMetricNames}; }, [callIds, compositeMetrics, state.data.evaluationCalls]); }; From af81503aa330c0659a8c65b316b1071782a543e6 Mon Sep 17 00:00:00 2001 From: Jeff Raubitschek Date: Wed, 11 Dec 2024 12:34:22 -0800 Subject: [PATCH 37/60] chore(weave): update db to use replicated tables (#3148) --- .../test_clickhouse_trace_server_migrator.py | 230 ++++++++++++++++++ .../clickhouse_trace_server_migrator.py | 132 ++++++++-- 2 files changed, 336 insertions(+), 26 deletions(-) create mode 100644 tests/trace_server/test_clickhouse_trace_server_migrator.py diff --git a/tests/trace_server/test_clickhouse_trace_server_migrator.py b/tests/trace_server/test_clickhouse_trace_server_migrator.py new file mode 100644 index 00000000000..3a6a92f2479 --- /dev/null +++ b/tests/trace_server/test_clickhouse_trace_server_migrator.py @@ -0,0 +1,230 @@ +import types +from unittest.mock import Mock, call, patch + +import pytest + +from weave.trace_server import clickhouse_trace_server_migrator as trace_server_migrator +from weave.trace_server.clickhouse_trace_server_migrator import MigrationError + + +@pytest.fixture +def mock_costs(): + with patch( + "weave.trace_server.costs.insert_costs.should_insert_costs", return_value=False + ) as mock_should_insert: + with patch( + "weave.trace_server.costs.insert_costs.get_current_costs", return_value=[] + ) as mock_get_costs: + yield + + +@pytest.fixture +def migrator(): + ch_client = Mock() + migrator = trace_server_migrator.ClickHouseTraceServerMigrator(ch_client) + migrator._get_migration_status = Mock() + migrator._get_migrations = Mock() + migrator._determine_migrations_to_apply = Mock() + migrator._update_migration_status = Mock() + ch_client.command.reset_mock() + return migrator + + +def test_apply_migrations_with_target_version(mock_costs, migrator, tmp_path): + # Setup + migrator._get_migration_status.return_value = { + "curr_version": 1, + "partially_applied_version": None, + } + migrator._get_migrations.return_value = { + "1": {"up": "1.up.sql", "down": "1.down.sql"}, + "2": {"up": "2.up.sql", "down": "2.down.sql"}, + } + migrator._determine_migrations_to_apply.return_value = [(2, "2.up.sql")] + + # Create a temporary migration file + migration_dir = tmp_path / "migrations" + migration_dir.mkdir() + migration_file = migration_dir / "2.up.sql" + migration_file.write_text( + "CREATE TABLE test1 (id Int32);\nCREATE TABLE test2 (id Int32);" + ) + + # Mock the migration directory path + with patch("os.path.dirname") as mock_dirname: + mock_dirname.return_value = str(tmp_path) + + # Execute + migrator.apply_migrations("test_db", target_version=2) + + # Verify + migrator._get_migration_status.assert_called_once_with("test_db") + migrator._get_migrations.assert_called_once() + migrator._determine_migrations_to_apply.assert_called_once_with( + 1, migrator._get_migrations.return_value, 2 + ) + + # Verify migration execution + assert migrator._update_migration_status.call_count == 2 + migrator._update_migration_status.assert_has_calls( + [call("test_db", 2, is_start=True), call("test_db", 2, is_start=False)] + ) + + # Verify the actual SQL commands were executed + ch_client = migrator.ch_client + assert ch_client.command.call_count == 2 + ch_client.command.assert_has_calls( + [call("CREATE TABLE test1 (id Int32)"), call("CREATE TABLE test2 (id Int32)")] + ) + + +def test_execute_migration_command(mock_costs, migrator): + # Setup + ch_client = migrator.ch_client + ch_client.database = "original_db" + + # Execute + migrator._execute_migration_command("test_db", "CREATE TABLE test (id Int32)") + + # Verify + assert ch_client.database == "original_db" # Should restore original database + ch_client.command.assert_called_once_with("CREATE TABLE test (id Int32)") + + +def test_migration_replicated(mock_costs, migrator): + ch_client = migrator.ch_client + orig = "CREATE TABLE test (id String, project_id String) ENGINE = MergeTree ORDER BY (project_id, id);" + migrator._execute_migration_command("test_db", orig) + ch_client.command.assert_called_once_with(orig) + + +def test_update_migration_status(mock_costs, migrator): + # Don't mock _update_migration_status for this test + migrator._update_migration_status = types.MethodType( + trace_server_migrator.ClickHouseTraceServerMigrator._update_migration_status, + migrator, + ) + + # Test start of migration + migrator._update_migration_status("test_db", 2, is_start=True) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE partially_applied_version = 2 WHERE db_name = 'test_db'" + ) + + # Test end of migration + migrator._update_migration_status("test_db", 2, is_start=False) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE curr_version = 2, partially_applied_version = NULL WHERE db_name = 'test_db'" + ) + + +def test_is_safe_identifier(mock_costs, migrator): + # Valid identifiers + assert migrator._is_safe_identifier("test_db") + assert migrator._is_safe_identifier("my_db123") + assert migrator._is_safe_identifier("db.table") + + # Invalid identifiers + assert not migrator._is_safe_identifier("test-db") + assert not migrator._is_safe_identifier("db;") + assert not migrator._is_safe_identifier("db'name") + assert not migrator._is_safe_identifier("db/*") + + +def test_create_db_sql_validation(mock_costs, migrator): + # Test invalid database name + with pytest.raises(MigrationError, match="Invalid database name"): + migrator._create_db_sql("test;db") + + # Test replicated mode with invalid values + migrator.replicated = True + migrator.replicated_cluster = "test;cluster" + with pytest.raises(MigrationError, match="Invalid cluster name"): + migrator._create_db_sql("test_db") + + migrator.replicated_cluster = "test_cluster" + migrator.replicated_path = "/clickhouse/bad;path/{db}" + with pytest.raises(MigrationError, match="Invalid replicated path"): + migrator._create_db_sql("test_db") + + +def test_create_db_sql_non_replicated(mock_costs, migrator): + # Test non-replicated mode + migrator.replicated = False + sql = migrator._create_db_sql("test_db") + assert sql.strip() == "CREATE DATABASE IF NOT EXISTS test_db" + + +def test_create_db_sql_replicated(mock_costs, migrator): + # Test replicated mode + migrator.replicated = True + migrator.replicated_path = "/clickhouse/tables/{db}" + migrator.replicated_cluster = "test_cluster" + + sql = migrator._create_db_sql("test_db") + expected = """ + CREATE DATABASE IF NOT EXISTS test_db ON CLUSTER test_cluster ENGINE=Replicated('/clickhouse/tables/test_db', '{shard}', '{replica}') + """.strip() + assert sql.strip() == expected + + +def test_format_replicated_sql_non_replicated(mock_costs, migrator): + # Test that SQL is unchanged when replicated=False + migrator.replicated = False + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql + + +def test_format_replicated_sql_replicated(mock_costs, migrator): + # Test that MergeTree engines are converted to Replicated variants + migrator.replicated = True + + test_cases = [ + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedSummingMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedReplacingMergeTree", + ), + # Test with extra whitespace + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + # Test with parameters + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree()", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree()", + ), + ] + + for input_sql, expected_sql in test_cases: + assert migrator._format_replicated_sql(input_sql) == expected_sql + + +def test_format_replicated_sql_non_mergetree(mock_costs, migrator): + # Test that non-MergeTree engines are left unchanged + migrator.replicated = True + + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = Memory", + "CREATE TABLE test (id Int32) ENGINE = Log", + "CREATE TABLE test (id Int32) ENGINE = TinyLog", + # This should not be changed as it's not a complete word match + "CREATE TABLE test (id Int32) ENGINE = MyMergeTreeCustom", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql diff --git a/weave/trace_server/clickhouse_trace_server_migrator.py b/weave/trace_server/clickhouse_trace_server_migrator.py index 30dffe89365..4336630bf50 100644 --- a/weave/trace_server/clickhouse_trace_server_migrator.py +++ b/weave/trace_server/clickhouse_trace_server_migrator.py @@ -1,6 +1,7 @@ # Clickhouse Trace Server Manager import logging import os +import re from typing import Optional from clickhouse_connect.driver.client import Client as CHClient @@ -9,6 +10,11 @@ logger = logging.getLogger(__name__) +# These settings are only used when `replicated` mode is enabled for +# self managed clickhouse instances. +DEFAULT_REPLICATED_PATH = "/clickhouse/tables/{db}" +DEFAULT_REPLICATED_CLUSTER = "weave_cluster" + class MigrationError(RuntimeError): """Raised when a migration error occurs.""" @@ -16,15 +22,77 @@ class MigrationError(RuntimeError): class ClickHouseTraceServerMigrator: ch_client: CHClient + replicated: bool + replicated_path: str + replicated_cluster: str def __init__( self, ch_client: CHClient, + replicated: Optional[bool] = None, + replicated_path: Optional[str] = None, + replicated_cluster: Optional[str] = None, ): super().__init__() self.ch_client = ch_client + self.replicated = False if replicated is None else replicated + self.replicated_path = ( + DEFAULT_REPLICATED_PATH if replicated_path is None else replicated_path + ) + self.replicated_cluster = ( + DEFAULT_REPLICATED_CLUSTER + if replicated_cluster is None + else replicated_cluster + ) self._initialize_migration_db() + def _is_safe_identifier(self, value: str) -> bool: + """Check if a string is safe to use as an identifier in SQL.""" + return bool(re.match(r"^[a-zA-Z0-9_\.]+$", value)) + + def _format_replicated_sql(self, sql_query: str) -> str: + """Format SQL query to use replicated engines if replicated mode is enabled.""" + if not self.replicated: + return sql_query + + # Match "ENGINE = MergeTree" followed by word boundary + pattern = r"ENGINE\s*=\s*(\w+)?MergeTree\b" + + def replace_engine(match: re.Match[str]) -> str: + engine_prefix = match.group(1) or "" + return f"ENGINE = Replicated{engine_prefix}MergeTree" + + return re.sub(pattern, replace_engine, sql_query, flags=re.IGNORECASE) + + def _create_db_sql(self, db_name: str) -> str: + """Geneate SQL database create string for normal and replicated databases.""" + if not self._is_safe_identifier(db_name): + raise MigrationError(f"Invalid database name: {db_name}") + + replicated_engine = "" + replicated_cluster = "" + if self.replicated: + if not self._is_safe_identifier(self.replicated_cluster): + raise MigrationError(f"Invalid cluster name: {self.replicated_cluster}") + + replicated_path = self.replicated_path.replace("{db}", db_name) + if not all( + self._is_safe_identifier(part) + for part in replicated_path.split("/") + if part + ): + raise MigrationError(f"Invalid replicated path: {replicated_path}") + + replicated_cluster = f" ON CLUSTER {self.replicated_cluster}" + replicated_engine = ( + f" ENGINE=Replicated('{replicated_path}', '{{shard}}', '{{replica}}')" + ) + + create_db_sql = f""" + CREATE DATABASE IF NOT EXISTS {db_name}{replicated_cluster}{replicated_engine} + """ + return create_db_sql + def apply_migrations( self, target_db: str, target_version: Optional[int] = None ) -> None: @@ -46,20 +114,15 @@ def apply_migrations( return logger.info(f"Migrations to apply: {migrations_to_apply}") if status["curr_version"] == 0: - self.ch_client.command(f"CREATE DATABASE IF NOT EXISTS {target_db}") + self.ch_client.command(self._create_db_sql(target_db)) for target_version, migration_file in migrations_to_apply: self._apply_migration(target_db, target_version, migration_file) if should_insert_costs(status["curr_version"], target_version): insert_costs(self.ch_client, target_db) def _initialize_migration_db(self) -> None: - self.ch_client.command( - """ - CREATE DATABASE IF NOT EXISTS db_management - """ - ) - self.ch_client.command( - """ + self.ch_client.command(self._create_db_sql("db_management")) + create_table_sql = """ CREATE TABLE IF NOT EXISTS db_management.migrations ( db_name String, @@ -69,7 +132,7 @@ def _initialize_migration_db(self) -> None: ENGINE = MergeTree() ORDER BY (db_name) """ - ) + self.ch_client.command(self._format_replicated_sql(create_table_sql)) def _get_migration_status(self, db_name: str) -> dict: column_names = ["db_name", "curr_version", "partially_applied_version"] @@ -184,31 +247,48 @@ def _determine_migrations_to_apply( return [] + def _execute_migration_command(self, target_db: str, command: str) -> None: + """Execute a single migration command in the context of the target database.""" + command = command.strip() + if len(command) == 0: + return + curr_db = self.ch_client.database + self.ch_client.database = target_db + self.ch_client.command(self._format_replicated_sql(command)) + self.ch_client.database = curr_db + + def _update_migration_status( + self, target_db: str, target_version: int, is_start: bool = True + ) -> None: + """Update the migration status in db_management.migrations table.""" + if is_start: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}'" + ) + else: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}'" + ) + def _apply_migration( self, target_db: str, target_version: int, migration_file: str ) -> None: logger.info(f"Applying migration {migration_file} to `{target_db}`") migration_dir = os.path.join(os.path.dirname(__file__), "migrations") migration_file_path = os.path.join(migration_dir, migration_file) + with open(migration_file_path) as f: migration_sql = f.read() - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}' - """ - ) + + # Mark migration as partially applied + self._update_migration_status(target_db, target_version, is_start=True) + + # Execute each command in the migration migration_sub_commands = migration_sql.split(";") for command in migration_sub_commands: - command = command.strip() - if len(command) == 0: - continue - curr_db = self.ch_client.database - self.ch_client.database = target_db - self.ch_client.command(command) - self.ch_client.database = curr_db - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}' - """ - ) + self._execute_migration_command(target_db, command) + + # Mark migration as fully applied + self._update_migration_status(target_db, target_version, is_start=False) + logger.info(f"Migration {migration_file} applied to `{target_db}`") From 4c6db183be9b356dfc6781670c5b8f5fb6fd7532 Mon Sep 17 00:00:00 2001 From: Josiah Lee Date: Wed, 11 Dec 2024 12:54:59 -0800 Subject: [PATCH 38/60] add choices drawer (#3203) --- .../Browse3/pages/ChatView/ChoiceView.tsx | 8 +- .../Browse3/pages/ChatView/ChoicesDrawer.tsx | 104 ++++++++++++++++++ .../Browse3/pages/ChatView/ChoicesView.tsx | 39 ++++--- .../pages/ChatView/ChoicesViewCarousel.tsx | 27 +++-- .../pages/ChatView/ChoicesViewLinear.tsx | 73 ------------ .../Home/Browse3/pages/ChatView/types.ts | 2 - 6 files changed, 148 insertions(+), 105 deletions(-) create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx delete mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewLinear.tsx diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx index c8143f9549e..d1a2c59d5d0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx @@ -6,15 +6,21 @@ import {Choice} from './types'; type ChoiceViewProps = { choice: Choice; isStructuredOutput?: boolean; + isNested?: boolean; }; -export const ChoiceView = ({choice, isStructuredOutput}: ChoiceViewProps) => { +export const ChoiceView = ({ + choice, + isStructuredOutput, + isNested, +}: ChoiceViewProps) => { const {message} = choice; return ( ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx new file mode 100644 index 00000000000..16d6897c27b --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx @@ -0,0 +1,104 @@ +import {Box, Drawer} from '@mui/material'; +import {MOON_200} from '@wandb/weave/common/css/color.styles'; +import {Tag} from '@wandb/weave/components/Tag'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import React from 'react'; + +import {Button} from '../../../../../Button'; +import {ChoiceView} from './ChoiceView'; +import {Choice} from './types'; + +type ChoicesDrawerProps = { + choices: Choice[]; + isStructuredOutput?: boolean; + isDrawerOpen: boolean; + setIsDrawerOpen: React.Dispatch>; + selectedChoiceIndex: number; + setSelectedChoiceIndex: (choiceIndex: number) => void; +}; + +export const ChoicesDrawer = ({ + choices, + isStructuredOutput, + isDrawerOpen, + setIsDrawerOpen, + selectedChoiceIndex, + setSelectedChoiceIndex, +}: ChoicesDrawerProps) => { + return ( + setIsDrawerOpen(false)} + title="Choices" + anchor="right" + sx={{ + '& .MuiDrawer-paper': {mt: '60px', width: '400px'}, + }}> + + + Responses + + + ) : ( + + )} +
+ +
+ ))} +
+ + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx index c22df7c63d7..5ddc7f12202 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx @@ -1,9 +1,9 @@ import React, {useState} from 'react'; +import {ChoicesDrawer} from './ChoicesDrawer'; import {ChoicesViewCarousel} from './ChoicesViewCarousel'; -import {ChoicesViewLinear} from './ChoicesViewLinear'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewProps = { choices: Choice[]; @@ -14,7 +14,12 @@ export const ChoicesView = ({ choices, isStructuredOutput, }: ChoicesViewProps) => { - const [mode, setMode] = useState('linear'); + const [isDrawerOpen, setIsDrawerOpen] = useState(false); + const [localSelectedChoiceIndex, setLocalSelectedChoiceIndex] = useState(0); + + const handleSetSelectedChoiceIndex = (choiceIndex: number) => { + setLocalSelectedChoiceIndex(choiceIndex); + }; if (choices.length === 0) { return null; @@ -26,20 +31,20 @@ export const ChoicesView = ({ } return ( <> - {mode === 'linear' && ( - - )} - {mode === 'carousel' && ( - - )} + + ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx index f4a52fc6801..a34932dea17 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx @@ -1,34 +1,37 @@ -import React, {useState} from 'react'; +import React from 'react'; import {Button} from '../../../../../Button'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewCarouselProps = { choices: Choice[]; isStructuredOutput?: boolean; - setMode: React.Dispatch>; + setIsDrawerOpen: React.Dispatch>; + selectedChoiceIndex: number; + setSelectedChoiceIndex: (choiceIndex: number) => void; }; export const ChoicesViewCarousel = ({ choices, isStructuredOutput, - setMode, + setIsDrawerOpen, + selectedChoiceIndex, + setSelectedChoiceIndex, }: ChoicesViewCarouselProps) => { - const [step, setStep] = useState(0); - const onNext = () => { - setStep((step + 1) % choices.length); + setSelectedChoiceIndex((selectedChoiceIndex + 1) % choices.length); }; const onBack = () => { - const newStep = step === 0 ? choices.length - 1 : step - 1; - setStep(newStep); + const newStep = + selectedChoiceIndex === 0 ? choices.length - 1 : selectedChoiceIndex - 1; + setSelectedChoiceIndex(newStep); }; return ( <>
@@ -37,7 +40,7 @@ export const ChoicesViewCarousel = ({ size="small" variant="quiet" icon="expand-uncollapse" - onClick={() => setMode('linear')} + onClick={() => setIsDrawerOpen(true)} tooltip="Switch to linear view" />
@@ -48,7 +51,7 @@ export const ChoicesViewCarousel = ({ size="small" onClick={onBack} /> - {step + 1} of {choices.length} + {selectedChoiceIndex + 1} of {choices.length}
-
- - - ))} - - ); -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts index b5696055712..3bbe65a5baf 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts @@ -113,5 +113,3 @@ export type Chat = { request: ChatRequest | null; result: ChatCompletion | null; }; - -export type ChoicesMode = 'linear' | 'carousel'; From 0b2d99c6afb9a0c3c7e95e86ae03b7faa43d8acd Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 11 Dec 2024 12:56:01 -0800 Subject: [PATCH 39/60] chore(weave): hide admin only scorers tabs (#3207) * re-hide * init * init --- .../Home/Browse3/pages/CallPage/CallPage.tsx | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index 484b038c193..a3315266f65 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -1,5 +1,4 @@ import Box from '@mui/material/Box'; -import {useViewerInfo} from '@wandb/weave/common/hooks/useViewerInfo'; import {Loading} from '@wandb/weave/components/Loading'; import {urlPrefixed} from '@wandb/weave/config'; import {useViewTraceEvent} from '@wandb/weave/integrations/analytics/useViewEvents'; @@ -71,8 +70,10 @@ export const CallPage: FC<{ }; export const useShowRunnableUI = () => { - const viewerInfo = useViewerInfo(); - return viewerInfo.loading ? false : viewerInfo.userInfo?.admin; + return false; + // Uncomment to re-enable + // const viewerInfo = useViewerInfo(); + // return viewerInfo.loading ? false : viewerInfo.userInfo?.admin; }; const useCallTabs = (call: CallSchema) => { From 3e43f5f7e5eace87e857889c1dce814983febb42 Mon Sep 17 00:00:00 2001 From: Josiah Lee Date: Wed, 11 Dec 2024 13:14:41 -0800 Subject: [PATCH 40/60] chore(weave): fix choices in weave (#3204) * add choices drawer * fix choices in playground --- .../Browse3/pages/ChatView/ChoiceView.tsx | 4 ++- .../Browse3/pages/ChatView/ChoicesDrawer.tsx | 1 + .../Browse3/pages/ChatView/ChoicesView.tsx | 5 ++++ .../pages/ChatView/ChoicesViewCarousel.tsx | 1 + .../Browse3/pages/ChatView/MessagePanel.tsx | 8 +++--- .../PlaygroundMessagePanelButtons.tsx | 10 +++---- .../ChatView/PlaygroundMessagePanelEditor.tsx | 10 +++---- .../PlaygroundChat/PlaygroundChat.tsx | 10 +++++-- .../useChatCompletionFunctions.tsx | 28 ++++++++++++------- .../PlaygroundChat/useChatFunctions.tsx | 13 ++------- .../PlaygroundPage/PlaygroundContext.tsx | 7 +++-- .../Browse3/pages/PlaygroundPage/types.ts | 1 + .../PlaygroundPage/usePlaygroundState.ts | 1 + 13 files changed, 60 insertions(+), 39 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx index d1a2c59d5d0..9e24e7e0a4a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx @@ -7,12 +7,14 @@ type ChoiceViewProps = { choice: Choice; isStructuredOutput?: boolean; isNested?: boolean; + choiceIndex?: number; }; export const ChoiceView = ({ choice, isStructuredOutput, isNested, + choiceIndex, }: ChoiceViewProps) => { const {message} = choice; return ( @@ -21,7 +23,7 @@ export const ChoiceView = ({ message={message} isStructuredOutput={isStructuredOutput} isNested={isNested} - isChoice + choiceIndex={choiceIndex} /> ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx index 16d6897c27b..1e7571cea53 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx @@ -93,6 +93,7 @@ export const ChoicesDrawer = ({ diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx index 5ddc7f12202..138ca10c7e8 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx @@ -1,5 +1,6 @@ import React, {useState} from 'react'; +import {usePlaygroundContext} from '../PlaygroundPage/PlaygroundContext'; import {ChoicesDrawer} from './ChoicesDrawer'; import {ChoicesViewCarousel} from './ChoicesViewCarousel'; import {ChoiceView} from './ChoiceView'; @@ -14,11 +15,15 @@ export const ChoicesView = ({ choices, isStructuredOutput, }: ChoicesViewProps) => { + const {setSelectedChoiceIndex: setGlobalSelectedChoiceIndex} = + usePlaygroundContext(); + const [isDrawerOpen, setIsDrawerOpen] = useState(false); const [localSelectedChoiceIndex, setLocalSelectedChoiceIndex] = useState(0); const handleSetSelectedChoiceIndex = (choiceIndex: number) => { setLocalSelectedChoiceIndex(choiceIndex); + setGlobalSelectedChoiceIndex(choiceIndex); }; if (choices.length === 0) { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx index a34932dea17..18760817665 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx @@ -33,6 +33,7 @@ export const ChoicesViewCarousel = ({
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx index cc1911b60d4..1e778727522 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx @@ -15,7 +15,7 @@ type MessagePanelProps = { index: number; message: Message; isStructuredOutput?: boolean; - isChoice?: boolean; + choiceIndex?: number; isNested?: boolean; pendingToolResponseId?: string; }; @@ -24,7 +24,7 @@ export const MessagePanel = ({ index, message, isStructuredOutput, - isChoice, + choiceIndex, isNested, // The id of the tool call response that is pending // If the tool call response is pending, the editor will be shown automatically @@ -120,7 +120,7 @@ export const MessagePanel = ({ ; @@ -17,7 +17,7 @@ export const PlaygroundMessagePanelButtons: React.FC< PlaygroundMessagePanelButtonsProps > = ({ index, - isChoice, + choiceIndex, isTool, hasContent, contentRef, @@ -32,7 +32,7 @@ export const PlaygroundMessagePanelButtons: React.FC< variant="quiet" size="small" startIcon="randomize-reset-reload" - onClick={() => retry?.(index, isChoice)} + onClick={() => retry?.(index, choiceIndex)} tooltip={ !hasContent ? 'We currently do not support retrying functions' @@ -64,8 +64,8 @@ export const PlaygroundMessagePanelButtons: React.FC< size="small" startIcon="delete" onClick={() => { - if (isChoice) { - deleteChoice?.(index); + if (choiceIndex !== undefined) { + deleteChoice?.(index, choiceIndex); } else { deleteMessage?.(index, responseIndexes); } diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelEditor.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelEditor.tsx index 746b033579a..aa519c9659b 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelEditor.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelEditor.tsx @@ -13,7 +13,7 @@ type PlaygroundMessagePanelEditorProps = { pendingToolResponseId?: string; message: Message; index: number; - isChoice: boolean; + choiceIndex?: number; setEditorHeight: (height: number | null) => void; }; @@ -21,7 +21,7 @@ export const PlaygroundMessagePanelEditor: React.FC< PlaygroundMessagePanelEditorProps > = ({ index, - isChoice, + choiceIndex, setEditorHeight, editorHeight, isNested, @@ -45,10 +45,10 @@ export const PlaygroundMessagePanelEditor: React.FC< }, [initialContent]); const handleSave = () => { - if (isChoice) { - editChoice?.(index, { + if (choiceIndex !== undefined) { + editChoice?.(choiceIndex, { + ...message, content: editedContent, - role: message.role, }); } else { editMessage?.(index, { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx index 94cd3c17644..b6b6e7c420d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx @@ -161,8 +161,8 @@ export const PlaygroundChat = ({ addMessage: newMessage => addMessage(idx, newMessage), editChoice: (choiceIndex, newChoice) => editChoice(idx, choiceIndex, newChoice), - retry: (messageIndex: number, isChoice?: boolean) => - handleRetry(idx, messageIndex, isChoice), + retry: (messageIndex: number, choiceIndex?: number) => + handleRetry(idx, messageIndex, choiceIndex), sendMessage: ( role: PlaygroundMessageRole, content: string, @@ -170,6 +170,12 @@ export const PlaygroundChat = ({ ) => { handleSend(role, idx, content, toolCallId); }, + setSelectedChoiceIndex: (choiceIndex: number) => + setPlaygroundStateField( + idx, + 'selectedChoiceIndex', + choiceIndex + ), }}> diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx index 10c76fc82d1..c73e5d42919 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx @@ -70,7 +70,7 @@ export const useChatCompletionFunctions = ( if (callIndex !== undefined && callIndex !== index) { return state; } - const updatedState = appendChoicesToMessages(state); + const updatedState = appendChoiceToMessages(state); if (updatedState.traceCall?.inputs?.messages) { updatedState.traceCall.inputs.messages.push(newMessage); } @@ -99,14 +99,14 @@ export const useChatCompletionFunctions = ( const handleRetry = async ( callIndex: number, messageIndex: number, - isChoice?: boolean + choiceIndex?: number ) => { try { setIsLoading(true); const updatedStates = playgroundStates.map((state, index) => { if (index === callIndex) { - if (isChoice) { - return appendChoicesToMessages(state); + if (choiceIndex !== undefined) { + return appendChoiceToMessages(state, choiceIndex); } const updatedState = JSON.parse(JSON.stringify(state)); if (updatedState.traceCall?.inputs?.messages) { @@ -203,17 +203,25 @@ const handleUpdateCallWithResponse = ( }; }; -const appendChoicesToMessages = (state: PlaygroundState): PlaygroundState => { +const appendChoiceToMessages = ( + state: PlaygroundState, + choiceIndex?: number +): PlaygroundState => { const updatedState = JSON.parse(JSON.stringify(state)); if ( updatedState.traceCall?.inputs?.messages && updatedState.traceCall.output?.choices ) { - updatedState.traceCall.output.choices.forEach((choice: any) => { - if (choice.message) { - updatedState.traceCall.inputs.messages.push(choice.message); - } - }); + if (choiceIndex !== undefined) { + updatedState.traceCall.inputs.messages.push( + updatedState.traceCall.output.choices[choiceIndex].message + ); + } else { + updatedState.traceCall.inputs.messages.push( + updatedState.traceCall.output.choices[updatedState.selectedChoiceIndex] + .message + ); + } updatedState.traceCall.output.choices = undefined; } return updatedState; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx index e84e2f75d4b..804670a1dc3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx @@ -114,16 +114,9 @@ export const useChatFunctions = ( newTraceCall?.output && Array.isArray((newTraceCall.output as TraceCallOutput).choices) ) { - // Delete the old choice - (newTraceCall.output as TraceCallOutput).choices!.splice( - choiceIndex, - 1 - ); - - // Add the new choice as a message - newTraceCall.inputs = newTraceCall.inputs ?? {}; - newTraceCall.inputs.messages = newTraceCall.inputs.messages ?? []; - newTraceCall.inputs.messages.push(newChoice); + // Replace the choice + (newTraceCall.output as TraceCallOutput).choices![choiceIndex].message = + newChoice; } return newTraceCall; }); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx index a8176292d1a..31369602560 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx @@ -10,14 +10,16 @@ export type PlaygroundContextType = { deleteMessage: (messageIndex: number, responseIndexes?: number[]) => void; editChoice: (choiceIndex: number, newChoice: Message) => void; - deleteChoice: (choiceIndex: number) => void; + deleteChoice: (messageIndex: number, choiceIndex: number) => void; - retry: (messageIndex: number, isChoice?: boolean) => void; + retry: (messageIndex: number, choiceIndex?: number) => void; sendMessage: ( role: PlaygroundMessageRole, content: string, toolCallId?: string ) => void; + + setSelectedChoiceIndex: (choiceIndex: number) => void; }; const DEFAULT_CONTEXT: PlaygroundContextType = { @@ -31,6 +33,7 @@ const DEFAULT_CONTEXT: PlaygroundContextType = { retry: () => {}, sendMessage: () => {}, + setSelectedChoiceIndex: () => {}, }; // Create context that can be undefined diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts index fa73e87bf45..29495fcf535 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts @@ -23,6 +23,7 @@ export type PlaygroundState = { // nTimes: number; maxTokensLimit: number; model: LLMMaxTokensKey; + selectedChoiceIndex: number; }; export type PlaygroundStateKey = keyof PlaygroundState; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts index 8c556edaef2..b0c06ddc677 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts @@ -37,6 +37,7 @@ const DEFAULT_PLAYGROUND_STATE = { // nTimes: 1, maxTokensLimit: 16384, model: DEFAULT_MODEL, + selectedChoiceIndex: 0, }; export const usePlaygroundState = () => { From 1a9485f64e4ef7e1909d53cc31239bede8d17e95 Mon Sep 17 00:00:00 2001 From: Josiah Lee Date: Wed, 11 Dec 2024 13:19:30 -0800 Subject: [PATCH 41/60] unhide completion iterations (#3137) --- .../PlaygroundSettings/PlaygroundSettings.tsx | 41 ++++++++++++------- .../Browse3/pages/PlaygroundPage/types.ts | 2 +- .../PlaygroundPage/usePlaygroundState.ts | 10 ++--- 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx index 5a2e8fae32c..e0971d35bfb 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx @@ -59,6 +59,12 @@ export const PlaygroundSettings: React.FC = ({ gap: '4px', mt: 2, }}> + + setPlaygroundStateField(idx, 'responseFormat', value) + } + /> = ({ } /> - - setPlaygroundStateField(idx, 'responseFormat', value) + + setPlaygroundStateField(idx, 'stopSequences', value) } /> + {/* TODO: N times to run is not supported for all models */} + {/* TODO: rerun if this is not supported in the backend */} - setPlaygroundStateField(idx, 'temperature', value) + setPlaygroundStateField(idx, 'nTimes', value) } - label="Temperature" - value={playgroundState.temperature} + label="Completion iterations" + value={playgroundState.nTimes} /> = ({ value={playgroundState.maxTokens} /> - - setPlaygroundStateField(idx, 'stopSequences', value) + + setPlaygroundStateField(idx, 'temperature', value) } + label="Temperature" + value={playgroundState.temperature} /> = ({ label="Presence penalty" value={playgroundState.presencePenalty} /> + { } } } - // if (inputs.n) { - // newState.nTimes = parseInt(inputs.n, 10); - // } + if (inputs.n) { + newState.nTimes = parseInt(inputs.n, 10); + } if (inputs.temperature) { newState.temperature = parseFloat(inputs.temperature); } @@ -148,7 +148,7 @@ export const getInputFromPlaygroundState = (state: PlaygroundState) => { top_p: state.topP, frequency_penalty: state.frequencyPenalty, presence_penalty: state.presencePenalty, - // n: state.nTimes, + n: state.nTimes, response_format: { type: state.responseFormat, }, From 74430483998c6f0ace4e248669bc58a81e54cb2d Mon Sep 17 00:00:00 2001 From: Weave Build Bot Date: Wed, 11 Dec 2024 21:28:01 +0000 Subject: [PATCH 42/60] =?UTF-8?q?Release=20version:=200.51.24-dev0=20?= =?UTF-8?q?=E2=86=92=200.51.24?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- weave/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d392d527b60..0f8eec3a30c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -226,7 +226,7 @@ module = "weave_query.*" ignore_errors = true [tool.bumpversion] -current_version = "0.51.24-dev0" +current_version = "0.51.24" parse = """(?x) (?P0|[1-9]\\d*)\\. (?P0|[1-9]\\d*)\\. diff --git a/weave/version.py b/weave/version.py index 5212f6aee7d..9ad7975da7d 100644 --- a/weave/version.py +++ b/weave/version.py @@ -44,4 +44,4 @@ """ -VERSION = "0.51.24-dev0" +VERSION = "0.51.24" From 43e2b7bd9a83ea8e1017f999994e08a933bc77a0 Mon Sep 17 00:00:00 2001 From: Weave Build Bot Date: Wed, 11 Dec 2024 21:28:01 +0000 Subject: [PATCH 43/60] =?UTF-8?q?Release=20version:=200.51.24=20=E2=86=92?= =?UTF-8?q?=200.51.25-dev0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- weave/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0f8eec3a30c..f34757b315e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -226,7 +226,7 @@ module = "weave_query.*" ignore_errors = true [tool.bumpversion] -current_version = "0.51.24" +current_version = "0.51.25-dev0" parse = """(?x) (?P0|[1-9]\\d*)\\. (?P0|[1-9]\\d*)\\. diff --git a/weave/version.py b/weave/version.py index 9ad7975da7d..70f670abc21 100644 --- a/weave/version.py +++ b/weave/version.py @@ -44,4 +44,4 @@ """ -VERSION = "0.51.24" +VERSION = "0.51.25-dev0" From 2c93fa653fb2b3c21e7a2b7d3fd40962e17eb02b Mon Sep 17 00:00:00 2001 From: Josiah Lee Date: Wed, 11 Dec 2024 13:53:24 -0800 Subject: [PATCH 44/60] style(weave): restyle carousel view (#3205) * add choices drawer * fix choices in playground * restyle carousel viewer * prettier --- weave-js/src/assets/icons/icon-visible.svg | 4 ++ weave-js/src/components/Icon/Icon.tsx | 5 ++ weave-js/src/components/Icon/index.ts | 1 + weave-js/src/components/Icon/types.ts | 1 + .../Browse3/pages/ChatView/ChoiceView.tsx | 3 + .../pages/ChatView/ChoicesViewCarousel.tsx | 57 ++++++++++--------- .../Browse3/pages/ChatView/MessagePanel.tsx | 3 + .../PlaygroundChat/PlaygroundChat.tsx | 1 + .../PlaygroundChat/useChatFunctions.tsx | 3 + 9 files changed, 50 insertions(+), 28 deletions(-) create mode 100644 weave-js/src/assets/icons/icon-visible.svg diff --git a/weave-js/src/assets/icons/icon-visible.svg b/weave-js/src/assets/icons/icon-visible.svg new file mode 100644 index 00000000000..823c36bb17f --- /dev/null +++ b/weave-js/src/assets/icons/icon-visible.svg @@ -0,0 +1,4 @@ + + + + diff --git a/weave-js/src/components/Icon/Icon.tsx b/weave-js/src/components/Icon/Icon.tsx index 9c020b1bab9..32b9ca65921 100644 --- a/weave-js/src/components/Icon/Icon.tsx +++ b/weave-js/src/components/Icon/Icon.tsx @@ -256,6 +256,7 @@ import {ReactComponent as ImportVersionsLayers} from '../../assets/icons/icon-ve import {ReactComponent as ImportVertexGCP} from '../../assets/icons/icon-vertex-gcp.svg'; import {ReactComponent as ImportVideoPlay} from '../../assets/icons/icon-video-play.svg'; import {ReactComponent as ImportViewGlasses} from '../../assets/icons/icon-view-glasses.svg'; +import {ReactComponent as ImportVisible} from '../../assets/icons/icon-visible.svg'; import {ReactComponent as ImportWandb} from '../../assets/icons/icon-wandb.svg'; import {ReactComponent as ImportWarning} from '../../assets/icons/icon-warning.svg'; import {ReactComponent as ImportWarningAlt} from '../../assets/icons/icon-warning-alt.svg'; @@ -1048,6 +1049,9 @@ export const IconVideoPlay = (props: SVGIconProps) => ( export const IconViewGlasses = (props: SVGIconProps) => ( ); +export const IconVisible = (props: SVGIconProps) => ( + +); export const IconWandb = (props: SVGIconProps) => ( ); @@ -1336,6 +1340,7 @@ const ICON_NAME_TO_ICON: Record = { 'vertex-gcp': IconVertexGCP, 'video-play': IconVideoPlay, 'view-glasses': IconViewGlasses, + visible: IconVisible, wandb: IconWandb, warning: IconWarning, 'warning-alt': IconWarningAlt, diff --git a/weave-js/src/components/Icon/index.ts b/weave-js/src/components/Icon/index.ts index 85ea5332649..08bf7854ad2 100644 --- a/weave-js/src/components/Icon/index.ts +++ b/weave-js/src/components/Icon/index.ts @@ -256,6 +256,7 @@ export { IconVertexGCP, IconVideoPlay, IconViewGlasses, + IconVisible, IconWandb, IconWarning, IconWarningAlt, diff --git a/weave-js/src/components/Icon/types.ts b/weave-js/src/components/Icon/types.ts index 87a1207bc85..7ca30049257 100644 --- a/weave-js/src/components/Icon/types.ts +++ b/weave-js/src/components/Icon/types.ts @@ -255,6 +255,7 @@ export const IconNames = { VertexGCP: 'vertex-gcp', VideoPlay: 'video-play', ViewGlasses: 'view-glasses', + Visible: 'visible', Wandb: 'wandb', Warning: 'warning', WarningAlt: 'warning-alt', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx index 9e24e7e0a4a..e511d6fbf5c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx @@ -8,6 +8,7 @@ type ChoiceViewProps = { isStructuredOutput?: boolean; isNested?: boolean; choiceIndex?: number; + messageHeader?: React.ReactNode; }; export const ChoiceView = ({ @@ -15,6 +16,7 @@ export const ChoiceView = ({ isStructuredOutput, isNested, choiceIndex, + messageHeader, }: ChoiceViewProps) => { const {message} = choice; return ( @@ -24,6 +26,7 @@ export const ChoiceView = ({ isStructuredOutput={isStructuredOutput} isNested={isNested} choiceIndex={choiceIndex} + messageHeader={messageHeader} /> ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx index 18760817665..b7dc6eb427d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx @@ -29,38 +29,39 @@ export const ChoicesViewCarousel = ({ }; return ( - <> - -
-
+ +
+
-
-
-
- + } + /> ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx index 1e778727522..f570b2f6295 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx @@ -18,6 +18,7 @@ type MessagePanelProps = { choiceIndex?: number; isNested?: boolean; pendingToolResponseId?: string; + messageHeader?: React.ReactNode; }; export const MessagePanel = ({ @@ -30,6 +31,7 @@ export const MessagePanel = ({ // If the tool call response is pending, the editor will be shown automatically // and on save the tool call response will be updated and sent to the LLM pendingToolResponseId, + messageHeader, }: MessagePanelProps) => { const [isShowingMore, setIsShowingMore] = useState(false); const [isOverflowing, setIsOverflowing] = useState(false); @@ -116,6 +118,7 @@ export const MessagePanel = ({ 'max-h-[400px]': !isShowingMore, 'max-h-full': isShowingMore, })}> + {messageHeader} {isPlayground && editorHeight ? ( { + console.log('playgroundStates', playgroundStates); const [chatText, setChatText] = useState(''); const [isLoading, setIsLoading] = useState(false); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx index 804670a1dc3..0ce3ad02b51 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx @@ -43,6 +43,8 @@ export const useChatFunctions = ( messageIndex: number, newMessage: Message ) => { + console.log('editMessage', callIndex, messageIndex, newMessage); + setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { const newTraceCall = clearTraceCall( cloneDeep(prevTraceCall as OptionalTraceCallSchema) @@ -106,6 +108,7 @@ export const useChatFunctions = ( choiceIndex: number, newChoice: Message ) => { + console.log('editChoice', callIndex, choiceIndex, newChoice); setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { const newTraceCall = clearTraceCall( cloneDeep(prevTraceCall as OptionalTraceCallSchema) From a0f12639451979efc01ef3b31767444542688ea4 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 11 Dec 2024 16:56:46 -0500 Subject: [PATCH 45/60] chore(weave): Add generic iterator for trace server API objects (#3177) --- weave/trace/weave_client.py | 207 ++++++++++++++++++++++-------------- 1 file changed, 129 insertions(+), 78 deletions(-) diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 0eca3fcbedb..1d5d54b9b23 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -10,7 +10,7 @@ from collections.abc import Iterator, Sequence from concurrent.futures import Future from functools import lru_cache -from typing import Any, Callable, cast +from typing import Any, Callable, Generic, Protocol, TypeVar, cast, overload import pydantic from requests import HTTPError @@ -90,6 +90,128 @@ logger = logging.getLogger(__name__) +T = TypeVar("T") +R = TypeVar("R", covariant=True) + + +class FetchFunc(Protocol[T]): + def __call__(self, offset: int, limit: int) -> list[T]: ... + + +TransformFunc = Callable[[T], R] + + +class PaginatedIterator(Generic[T, R]): + """An iterator that fetches pages of items from a server and optionally transforms them + into a more user-friendly type.""" + + def __init__( + self, + fetch_func: FetchFunc[T], + page_size: int = 1000, + transform_func: TransformFunc[T, R] | None = None, + ) -> None: + self.fetch_func = fetch_func + self.page_size = page_size + self.transform_func = transform_func + + if page_size <= 0: + raise ValueError("page_size must be greater than 0") + + @lru_cache + def _fetch_page(self, index: int) -> list[T]: + return self.fetch_func(index * self.page_size, self.page_size) + + @overload + def _get_one(self: PaginatedIterator[T, T], index: int) -> T: ... + @overload + def _get_one(self: PaginatedIterator[T, R], index: int) -> R: ... + def _get_one(self, index: int) -> T | R: + if index < 0: + raise IndexError("Negative indexing not supported") + + page_index = index // self.page_size + page_offset = index % self.page_size + + page = self._fetch_page(page_index) + if page_offset >= len(page): + raise IndexError(f"Index {index} out of range") + + res = page[page_offset] + if transform := self.transform_func: + return transform(res) + return res + + @overload + def _get_slice(self: PaginatedIterator[T, T], key: slice) -> Iterator[T]: ... + @overload + def _get_slice(self: PaginatedIterator[T, R], key: slice) -> Iterator[R]: ... + def _get_slice(self, key: slice) -> Iterator[T] | Iterator[R]: + if (start := key.start or 0) < 0: + raise ValueError("Negative start not supported") + if (stop := key.stop) is not None and stop < 0: + raise ValueError("Negative stop not supported") + if (step := key.step or 1) < 0: + raise ValueError("Negative step not supported") + + i = start + while stop is None or i < stop: + try: + yield self._get_one(i) + except IndexError: + break + i += step + + @overload + def __getitem__(self: PaginatedIterator[T, T], key: int) -> T: ... + @overload + def __getitem__(self: PaginatedIterator[T, R], key: int) -> R: ... + @overload + def __getitem__(self: PaginatedIterator[T, T], key: slice) -> list[T]: ... + @overload + def __getitem__(self: PaginatedIterator[T, R], key: slice) -> list[R]: ... + def __getitem__(self, key: slice | int) -> T | R | list[T] | list[R]: + if isinstance(key, slice): + return list(self._get_slice(key)) + return self._get_one(key) + + @overload + def __iter__(self: PaginatedIterator[T, T]) -> Iterator[T]: ... + @overload + def __iter__(self: PaginatedIterator[T, R]) -> Iterator[R]: ... + def __iter__(self) -> Iterator[T] | Iterator[R]: + return self._get_slice(slice(0, None, 1)) + + +# TODO: should be Call, not WeaveObject +CallsIter = PaginatedIterator[CallSchema, WeaveObject] + + +def _make_calls_iterator( + server: TraceServerInterface, + project_id: str, + filter: CallsFilter, + include_costs: bool = False, +) -> CallsIter: + def fetch_func(offset: int, limit: int) -> list[CallSchema]: + response = server.calls_query( + CallsQueryReq( + project_id=project_id, + filter=filter, + offset=offset, + limit=limit, + include_costs=include_costs, + ) + ) + return response.calls + + # TODO: Should be Call, not WeaveObject + def transform_func(call: CallSchema) -> WeaveObject: + entity, project = project_id.split("/") + return make_client_call(entity, project, call, server) + + return PaginatedIterator(fetch_func, transform_func=transform_func) + class OpNameError(ValueError): """Raised when an op name is invalid.""" @@ -284,7 +406,7 @@ def children(self) -> CallsIter: ) client = weave_client_context.require_weave_client() - return CallsIter( + return _make_calls_iterator( client.server, self.project_id, CallsFilter(parent_ids=[self.id]), @@ -362,80 +484,6 @@ def _apply_scorer(self, scorer_op: Op) -> None: ) -class CallsIter: - server: TraceServerInterface - filter: CallsFilter - include_costs: bool - - def __init__( - self, - server: TraceServerInterface, - project_id: str, - filter: CallsFilter, - include_costs: bool = False, - ) -> None: - self.server = server - self.project_id = project_id - self.filter = filter - self._page_size = 1000 - self.include_costs = include_costs - - # seems like this caching should be on the server, but it's here for now... - @lru_cache - def _fetch_page(self, index: int) -> list[CallSchema]: - # caching in here means that any other CallsIter objects would also - # benefit from the cache - response = self.server.calls_query( - CallsQueryReq( - project_id=self.project_id, - filter=self.filter, - offset=index * self._page_size, - limit=self._page_size, - include_costs=self.include_costs, - ) - ) - return response.calls - - def _get_one(self, index: int) -> WeaveObject: - if index < 0: - raise IndexError("Negative indexing not supported") - - page_index = index // self._page_size - page_offset = index % self._page_size - - calls = self._fetch_page(page_index) - if page_offset >= len(calls): - raise IndexError(f"Index {index} out of range") - - call = calls[page_offset] - entity, project = self.project_id.split("/") - return make_client_call(entity, project, call, self.server) - - def _get_slice(self, key: slice) -> Iterator[WeaveObject]: - if (start := key.start or 0) < 0: - raise ValueError("Negative start not supported") - if (stop := key.stop) is not None and stop < 0: - raise ValueError("Negative stop not supported") - if (step := key.step or 1) < 0: - raise ValueError("Negative step not supported") - - i = start - while stop is None or i < stop: - try: - yield self._get_one(i) - except IndexError: - break - i += step - - def __getitem__(self, key: slice | int) -> WeaveObject | list[WeaveObject]: - if isinstance(key, slice): - return list(self._get_slice(key)) - return self._get_one(key) - - def __iter__(self) -> Iterator[WeaveObject]: - return self._get_slice(slice(0, None, 1)) - - def make_client_call( entity: str, project: str, server_call: CallSchema, server: TraceServerInterface ) -> WeaveObject: @@ -642,8 +690,11 @@ def get_calls( if filter is None: filter = CallsFilter() - return CallsIter( - self.server, self._project_id(), filter, include_costs or False + return _make_calls_iterator( + self.server, + self._project_id(), + filter, + include_costs, ) @deprecated(new_name="get_calls") From 0525621c503ee6082a2d809f17bef64b04ea71ef Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 11 Dec 2024 17:05:51 -0500 Subject: [PATCH 46/60] feat(weave): Support op configuration for autopatched functions (starting with OpenAI) (#3197) --- tests/conftest.py | 12 +- .../test_configuration_with_dicts.yaml | 102 +++++++++++++ ...est_disabled_integration_doesnt_patch.yaml | 102 +++++++++++++ .../test_enabled_integration_patches.yaml | 102 +++++++++++++ .../test_passthrough_op_kwargs.yaml | 102 +++++++++++++ tests/integrations/openai/test_autopatch.py | 116 +++++++++++++++ weave/integrations/openai/openai_sdk.py | 136 +++++++++++------- weave/scorers/llm_utils.py | 4 - weave/trace/api.py | 9 +- weave/trace/autopatch.py | 59 +++++++- weave/trace/patcher.py | 8 ++ weave/trace/weave_init.py | 6 +- 12 files changed, 689 insertions(+), 69 deletions(-) create mode 100644 tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml create mode 100644 tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml create mode 100644 tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml create mode 100644 tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml create mode 100644 tests/integrations/openai/test_autopatch.py diff --git a/tests/conftest.py b/tests/conftest.py index b28187a3833..85e9b53c36b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -477,7 +477,9 @@ def __getattribute__(self, name): return ServerRecorder(server) -def create_client(request) -> weave_init.InitializedClient: +def create_client( + request, autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None +) -> weave_init.InitializedClient: inited_client = None weave_server_flag = request.config.getoption("--weave-server") server: tsi.TraceServerInterface @@ -513,7 +515,7 @@ def create_client(request) -> weave_init.InitializedClient: entity, project, make_server_recorder(server) ) inited_client = weave_init.InitializedClient(client) - autopatch.autopatch() + autopatch.autopatch(autopatch_settings) return inited_client @@ -527,6 +529,7 @@ def client(request): yield inited_client.client finally: inited_client.reset() + autopatch.reset_autopatch() @pytest.fixture() @@ -534,12 +537,13 @@ def client_creator(request): """This fixture is useful for delaying the creation of the client (ex. when you want to set settings first)""" @contextlib.contextmanager - def client(): - inited_client = create_client(request) + def client(autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None): + inited_client = create_client(request, autopatch_settings) try: yield inited_client.client finally: inited_client.reset() + autopatch.reset_autopatch() yield client diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml new file mode 100644 index 00000000000..7245829a0b3 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFJNa9wwEL37V0x1yWVd7P3KspcSSKE5thvooSlGK40tJbJGSOOSNOx/ + L/Z+2KEp9KLDe/Me743mNQMQVostCGUkqza4/Eav1f1O/v4aVvvbL/Pdt7tVHUp1s/vsnhsx6xW0 + f0TFZ9VHRW1wyJb8kVYRJWPvWl4vFpvNYl2UA9GSRtfLmsD5kvJ5MV/mxSYv1iehIaswiS38yAAA + Xoe3j+g1PostFLMz0mJKskGxvQwBiEiuR4RMySaWnsVsJBV5Rj+k/m5eQJO/YkhP6JDJJ6htYxhQ + KgPEBuOnB//g7w2eJ438hcAGoek4fZgaR6y7JPtevnPuhB8uSR01IdI+nfgLXltvk6kiykS+T5WY + ghjYQwbwc9hI96akCJHawBXTE/resCyPdmL8ggm5PJFMLN2Iz1ezd9wqjSytS5ONCiWVQT0qx/XL + TluaENmk899h3vM+9ra++R/7kVAKA6OuQkRt1dvC41jE/kD/NXbZ8RBYpJfE2Fa19Q3GEO3xRupQ + qWslC9xLJUV2yP4AAAD//wMA4O+DUSwDAAA= + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f01fe3aabd037cf-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 02:20:01 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=xqe_jHZdTV5LijJQYQ3GMY5MjtVrCyxbFO4glgLvgD0-1733883601-1.0.1.1-p.DDUca_cHppJu2hXzzA0CXU1mtalxHUNfBWVgPIQj.UkU603pbNscCvSIi4_Zjlz9Zuc3.hjlvoyZxcDBJTsw; + path=/; expires=Wed, 11-Dec-24 02:50:01 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=WEjxXqkGswaEDhllTROGX_go9tgaWNJcUJ3cCd50xDI-1733883601764-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '607' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_8592a74b531c806f65c63c7471101cb6 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml new file mode 100644 index 00000000000..1895cdcd5f2 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLJbtswEL3rK6a85GIV8lYvl6KX9NQFrYEckkKgyZHImOII5KiJEfjf + C8qLHDQFeuHhbXgzw5cMQFgt1iCUkaya1uWf9Hyuv5H29ebu89Pt80Z9//H0RS2nX/e3LEbJQdtH + VHx2vVfUtA7Zkj/SKqBkTKnjxXS6XCwWk3FPNKTRJVvdcj6jfFJMZnmxzIsPJ6MhqzCKNdxnAAAv + /Zsqeo3PYg3F6Iw0GKOsUawvIgARyCVEyBhtZOmPdU+kIs/o+9Y/u4AjMBjwJoIEZ2vDuUEZGDU8 + 0g6hogB76tYP/sHfmT1o8jcMcYcOmXyEKlkApTJAbDB8TMKNwbPSyN8IbBDqjuO76xoBqy7KtAXf + OXfCD5e5HNVtoG088Re8st5GUwaUkXyaITK1omcPGcCvfn/dq5WINlDTcsm0Q58Cx+NjnBgONpCT + 2YlkYukGfDofvZFWamRpXbzav1BSGdSDcziW7LSlKyK7mvnvMm9lH+e2vv6f+IFQCltGXbYBtVWv + Bx5kAdN3/pfssuO+sIj7yNiUlfU1hjbY44+q2nKl54XSq1WxFdkh+wMAAP//AwAWTTnuWgMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eadbff439d2-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:01 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=8FO1yMjc3pMQWRpWrkIe5mcs39GLeqQPmgHQq0YTT8s-1733877721-1.0.1.1-i4G06DBN08aH1F1H73U_TB9OLK3jLsV1jXydB1cQ4Hqx7I.r8xDn.7hFRZe2hy3D_nABTG1nDcdDoXL_wYiqug; + path=/; expires=Wed, 11-Dec-24 01:12:01 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=jxwySgtriPkUP8L2os1nb_gRq_SSUo3yWFUyJmHPmGY-1733877721989-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '652' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_1c86d4fda2ad715edfd41bcd2f4bdd89 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml new file mode 100644 index 00000000000..f0cdca54158 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLBjtMwEL3nKwZfuDQoTXc3VS8IBOIACCQOHHZR5NrTxDTxWJ4J2rDq + v6Ok2SYrFomLD+/Ne3pvxg8JgHJW7UCZWotpQ5O+sdfXaL9+/PDp+L6gG3ffyudv735v+y82eLUa + FLT/iUYeVa8MtaFBcTTRJqIWHFzXxWazLYoiz0eiJYvNIKuCpFeU5ll+lWbbNLuZhDU5g6x2cJsA + ADyM7xDRW7xXO8hWj0iLzLpCtbsMAahIzYAozexYtBe1mklDXtCPqb/XPVjyLwXYOPTiWBgkdiyg + hVp+fefv/Fs0umMEqbGHVh8RugD4C2MvtfPVi6V3xEPHeqjmu6aZ8NMlbENViLTnib/gB+cd12VE + zeSHYCwU1MieEoAf41K6Jz1ViNQGKYWO6AfD9fpsp+YrLMh8IoVENzOeb1bPuJUWRbuGF0tVRpsa + 7aycL6A762hBJIvOf4d5zvvc2/nqf+xnwhgMgrYMEa0zTwvPYxGHP/qvscuOx8CKexZsy4PzFcYQ + 3fmbHEJpCqMz3GujVXJK/gAAAP//AwAyhdwOLwMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eb36bb3a240-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:02 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=Q_ATX8JU4jFqXJPdwlneOua9wmNmAaASyAfcbPyPqng-1733877722-1.0.1.1-eTMEvBW7oqQa2i3l.Or2I3LF_cCESxfseq.S9DBr8dAJWsVoFfPxKtr5vMaO6yj4hRW8XOSOHcgIcwwqbHrLbg; + path=/; expires=Wed, 11-Dec-24 01:12:02 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=2ak.tRpn6uEHbM8GrWy_ALtrN34jVSNIJI1mFG2etvM-1733877722703-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '476' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_52e061e1cc55cdd8847a7ba9342f1a14 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml new file mode 100644 index 00000000000..646c57c6123 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLLbtswELzrK7a89GIVsmLXsS9Fr0UvBQIERVMINLkS2VBcglwVcQL/ + e0H5IQVNgV54mNkZzCz3pQAQVosdCGUkqz648rNer7G2z3fL8ITq+X6lvn7x/SHhN/d9LxZZQftf + qPii+qCoDw7Zkj/RKqJkzK7Lzc3N7WazqeuR6Emjy7IucLmisq7qVVndltXHs9CQVZjEDn4UAAAv + 45sjeo1PYgfV4oL0mJLsUOyuQwAiksuIkCnZxNKzWEykIs/ox9T35gCa/HuG9IgOmXyC1naGAaUy + QGwwfnrwD/7O4GXSyN8IbBC6gdO7uXHEdkgy9/KDc2f8eE3qqAuR9unMX/HWeptME1Em8jlVYgpi + ZI8FwM9xI8OrkiJE6gM3TI/os+FyebIT0xfMyNWZZGLpJrxeL95wazSytC7NNiqUVAb1pJzWLwdt + aUYUs85/h3nL+9Tb+u5/7CdCKQyMugkRtVWvC09jEfOB/mvsuuMxsEiHxNg3rfUdxhDt6Uba0Gz1 + ulJ6u632ojgWfwAAAP//AwCOwDMjLAMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eb76b71ac9a-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:03 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=r.xSSsYQNFPvMiizFSvjQiecNA6Q1wQa0VR1YElfXi4-1733877723-1.0.1.1-GVW0i7wrpHCQSY5eXu7sIQgxYWl6jfeSordQ7JFxV3lO6UfFhwxRT92bBP4DfnrSYpBpRw3k4aONAURyvKctiQ; + path=/; expires=Wed, 11-Dec-24 01:12:03 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=CQJVOdASzL9ency5_q6SDaInTsvpjA240cIxf.AUwXM-1733877723385-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '523' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_c9c57cfa6f37a99aaf0abac013237ed6 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/test_autopatch.py b/tests/integrations/openai/test_autopatch.py new file mode 100644 index 00000000000..2c2f5201d3f --- /dev/null +++ b/tests/integrations/openai/test_autopatch.py @@ -0,0 +1,116 @@ +# This is included here for convenience. Instead of creating a dummy API, we can test +# autopatching against the actual OpenAI API. + +from typing import Any + +import pytest +from openai import OpenAI + +from weave.integrations.openai import openai_sdk +from weave.trace.autopatch import AutopatchSettings, IntegrationSettings, OpSettings + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_disabled_integration_doesnt_patch(client_creator): + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings(enabled=False), + ) + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 0 + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_enabled_integration_patches(client_creator): + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings(enabled=True), + ) + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_passthrough_op_kwargs(client_creator): + def redact_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + return dict.fromkeys(inputs, "REDACTED") + + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings( + op_settings=OpSettings( + postprocess_inputs=redact_inputs, + ) + ) + ) + + # Explicitly reset the patcher here to pretend like we're starting fresh. We need + # to do this because `_openai_patcher` is a global variable that is shared across + # tests. If we don't reset it, it will retain the state from the previous test, + # which can cause this test to fail. + openai_sdk._openai_patcher = None + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + call = calls[0] + assert all(v == "REDACTED" for v in call.inputs.values()) + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_configuration_with_dicts(client_creator): + def redact_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + return dict.fromkeys(inputs, "REDACTED") + + autopatch_settings = { + "openai": { + "op_settings": {"postprocess_inputs": redact_inputs}, + } + } + + openai_sdk._openai_patcher = None + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + call = calls[0] + assert all(v == "REDACTED" for v in call.inputs.values()) diff --git a/weave/integrations/openai/openai_sdk.py b/weave/integrations/openai/openai_sdk.py index 7814700d4d3..a1e3a9b5831 100644 --- a/weave/integrations/openai/openai_sdk.py +++ b/weave/integrations/openai/openai_sdk.py @@ -1,15 +1,20 @@ +from __future__ import annotations + import importlib from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op import Op, ProcessedInputs from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from openai.types.chat import ChatCompletionChunk +_openai_patcher: MultiPatcher | None = None + def maybe_unwrap_api_response(value: Any) -> Any: """If the caller requests a raw response, we unwrap the APIResponse object. @@ -43,9 +48,7 @@ def maybe_unwrap_api_response(value: Any) -> Any: return value -def openai_on_finish_post_processor( - value: Optional["ChatCompletionChunk"], -) -> Optional[dict]: +def openai_on_finish_post_processor(value: ChatCompletionChunk | None) -> dict | None: from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion_chunk import ( ChoiceDeltaFunctionCall, @@ -60,8 +63,8 @@ def openai_on_finish_post_processor( value = maybe_unwrap_api_response(value) def _get_function_call( - function_call: Optional[ChoiceDeltaFunctionCall], - ) -> Optional[FunctionCall]: + function_call: ChoiceDeltaFunctionCall | None, + ) -> FunctionCall | None: if function_call is None: return function_call if isinstance(function_call, ChoiceDeltaFunctionCall): @@ -73,8 +76,8 @@ def _get_function_call( return None def _get_tool_calls( - tool_calls: Optional[list[ChoiceDeltaToolCall]], - ) -> Optional[list[ChatCompletionMessageToolCall]]: + tool_calls: list[ChoiceDeltaToolCall] | None, + ) -> list[ChatCompletionMessageToolCall] | None: if tool_calls is None: return tool_calls @@ -128,10 +131,10 @@ def _get_tool_calls( def openai_accumulator( - acc: Optional["ChatCompletionChunk"], - value: "ChatCompletionChunk", + acc: ChatCompletionChunk | None, + value: ChatCompletionChunk, skip_last: bool = False, -) -> "ChatCompletionChunk": +) -> ChatCompletionChunk: from openai.types.chat import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import ( ChoiceDeltaFunctionCall, @@ -285,7 +288,7 @@ def should_use_accumulator(inputs: dict) -> bool: def openai_on_input_handler( func: Op, args: tuple, kwargs: dict -) -> Optional[ProcessedInputs]: +) -> ProcessedInputs | None: if len(args) == 2 and isinstance(args[1], weave.EasyPrompt): original_args = args original_kwargs = kwargs @@ -305,20 +308,16 @@ def openai_on_input_handler( return None -def create_wrapper_sync( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: "We need to do this so we can check if `stream` is used" def _add_stream_options(fn: Callable) -> Callable: @wraps(fn) def _wrapper(*args: Any, **kwargs: Any) -> Any: - if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None: + if kwargs.get("stream") and kwargs.get("stream_options") is None: kwargs["stream_options"] = {"include_usage": True} - return fn( - *args, **kwargs - ) # This is where the final execution of fn is happening. + return fn(*args, **kwargs) return _wrapper @@ -327,8 +326,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return True return False - op = weave.op()(_add_stream_options(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_add_stream_options(fn), **op_kwargs) op._set_on_input_handler(openai_on_input_handler) return add_accumulator( op, # type: ignore @@ -345,16 +344,14 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: # Surprisingly, the async `client.chat.completions.create` does not pass # `inspect.iscoroutinefunction`, so we can't dispatch on it and must write # it manually here... -def create_wrapper_async( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: "We need to do this so we can check if `stream` is used" def _add_stream_options(fn: Callable) -> Callable: @wraps(fn) async def _wrapper(*args: Any, **kwargs: Any) -> Any: - if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None: + if kwargs.get("stream") and kwargs.get("stream_options") is None: kwargs["stream_options"] = {"include_usage": True} return await fn(*args, **kwargs) @@ -365,8 +362,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return True return False - op = weave.op()(_add_stream_options(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_add_stream_options(fn), **op_kwargs) op._set_on_input_handler(openai_on_input_handler) return add_accumulator( op, # type: ignore @@ -380,28 +377,61 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return wrapper -symbol_patchers = [ - # Patch the Completions.create method - SymbolPatcher( - lambda: importlib.import_module("openai.resources.chat.completions"), - "Completions.create", - create_wrapper_sync(name="openai.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.chat.completions"), - "AsyncCompletions.create", - create_wrapper_async(name="openai.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.beta.chat.completions"), - "Completions.parse", - create_wrapper_sync(name="openai.beta.chat.completions.parse"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.beta.chat.completions"), - "AsyncCompletions.parse", - create_wrapper_async(name="openai.beta.chat.completions.parse"), - ), -] - -openai_patcher = MultiPatcher(symbol_patchers) # type: ignore +def get_openai_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _openai_patcher + if _openai_patcher is not None: + return _openai_patcher + + base = settings.op_settings + + completions_create_settings = base.model_copy( + update={"name": base.name or "openai.chat.completions.create"} + ) + async_completions_create_settings = base.model_copy( + update={"name": base.name or "openai.chat.completions.create"} + ) + completions_parse_settings = base.model_copy( + update={"name": base.name or "openai.beta.chat.completions.parse"} + ) + async_completions_parse_settings = base.model_copy( + update={"name": base.name or "openai.beta.chat.completions.parse"} + ) + + _openai_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("openai.resources.chat.completions"), + "Completions.create", + create_wrapper_sync(settings=completions_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("openai.resources.chat.completions"), + "AsyncCompletions.create", + create_wrapper_async(settings=async_completions_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module( + "openai.resources.beta.chat.completions" + ), + "Completions.parse", + create_wrapper_sync(settings=completions_parse_settings), + ), + SymbolPatcher( + lambda: importlib.import_module( + "openai.resources.beta.chat.completions" + ), + "AsyncCompletions.parse", + create_wrapper_async(settings=async_completions_parse_settings), + ), + ] + ) + + return _openai_patcher diff --git a/weave/scorers/llm_utils.py b/weave/scorers/llm_utils.py index 68ae2ccb366..eef6f018b0f 100644 --- a/weave/scorers/llm_utils.py +++ b/weave/scorers/llm_utils.py @@ -2,10 +2,6 @@ from typing import TYPE_CHECKING, Any, Union -from weave.trace.autopatch import autopatch - -autopatch() # ensure both weave patching and instructor patching are applied - OPENAI_DEFAULT_MODEL = "gpt-4o" OPENAI_DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small" OPENAI_DEFAULT_MODERATION_MODEL = "text-moderation-latest" diff --git a/weave/trace/api.py b/weave/trace/api.py index ee8131b0875..294308cbb67 100644 --- a/weave/trace/api.py +++ b/weave/trace/api.py @@ -13,6 +13,7 @@ # There is probably a better place for this, but including here for now to get the fix in. from weave import type_handlers # noqa: F401 from weave.trace import urls, util, weave_client, weave_init +from weave.trace.autopatch import AutopatchSettings from weave.trace.constants import TRACE_OBJECT_EMOJI from weave.trace.context import call_context from weave.trace.context import weave_client_context as weave_client_context @@ -32,6 +33,7 @@ def init( project_name: str, *, settings: UserSettings | dict[str, Any] | None = None, + autopatch_settings: AutopatchSettings | None = None, ) -> weave_client.WeaveClient: """Initialize weave tracking, logging to a wandb project. @@ -52,7 +54,12 @@ def init( if should_disable_weave(): return weave_init.init_weave_disabled().client - return weave_init.init_weave(project_name).client + initialized_client = weave_init.init_weave( + project_name, + autopatch_settings=autopatch_settings, + ) + + return initialized_client.client @contextlib.contextmanager diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py index 3a5dca14556..0619194a224 100644 --- a/weave/trace/autopatch.py +++ b/weave/trace/autopatch.py @@ -4,8 +4,54 @@ check if libraries are installed and imported and patch in the case that they are. """ +from typing import Any, Callable, Optional, Union -def autopatch() -> None: +from pydantic import BaseModel, Field, validate_call + +from weave.trace.weave_client import Call + + +class OpSettings(BaseModel): + """Op settings for a specific integration. + These currently subset the `op` decorator args to provide a consistent interface + when working with auto-patched functions. See the `op` decorator for more details.""" + + name: Optional[str] = None + call_display_name: Optional[Union[str, Callable[[Call], str]]] = None + postprocess_inputs: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None + postprocess_output: Optional[Callable[[Any], Any]] = None + + +class IntegrationSettings(BaseModel): + """Configuration for a specific integration.""" + + enabled: bool = True + op_settings: OpSettings = Field(default_factory=OpSettings) + + +class AutopatchSettings(BaseModel): + """Settings for auto-patching integrations.""" + + # These will be uncommented as we add support for more integrations. Note that + + # anthropic: IntegrationSettings = Field(default_factory=IntegrationSettings) + # cerebras: IntegrationSettings = Field(default_factory=IntegrationSettings) + # cohere: IntegrationSettings = Field(default_factory=IntegrationSettings) + # dspy: IntegrationSettings = Field(default_factory=IntegrationSettings) + # google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings) + # groq: IntegrationSettings = Field(default_factory=IntegrationSettings) + # instructor: IntegrationSettings = Field(default_factory=IntegrationSettings) + # langchain: IntegrationSettings = Field(default_factory=IntegrationSettings) + # litellm: IntegrationSettings = Field(default_factory=IntegrationSettings) + # llamaindex: IntegrationSettings = Field(default_factory=IntegrationSettings) + # mistral: IntegrationSettings = Field(default_factory=IntegrationSettings) + # notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings) + openai: IntegrationSettings = Field(default_factory=IntegrationSettings) + # vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings) + + +@validate_call +def autopatch(settings: Optional[AutopatchSettings] = None) -> None: from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher from weave.integrations.cohere.cohere_sdk import cohere_patcher @@ -20,10 +66,13 @@ def autopatch() -> None: from weave.integrations.llamaindex.llamaindex import llamaindex_patcher from weave.integrations.mistral import mistral_patcher from weave.integrations.notdiamond.tracing import notdiamond_patcher - from weave.integrations.openai.openai_sdk import openai_patcher + from weave.integrations.openai.openai_sdk import get_openai_patcher from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher - openai_patcher.attempt_patch() + if settings is None: + settings = AutopatchSettings() + + get_openai_patcher(settings.openai).attempt_patch() mistral_patcher.attempt_patch() litellm_patcher.attempt_patch() llamaindex_patcher.attempt_patch() @@ -54,10 +103,10 @@ def reset_autopatch() -> None: from weave.integrations.llamaindex.llamaindex import llamaindex_patcher from weave.integrations.mistral import mistral_patcher from weave.integrations.notdiamond.tracing import notdiamond_patcher - from weave.integrations.openai.openai_sdk import openai_patcher + from weave.integrations.openai.openai_sdk import get_openai_patcher from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher - openai_patcher.undo_patch() + get_openai_patcher().undo_patch() mistral_patcher.undo_patch() litellm_patcher.undo_patch() llamaindex_patcher.undo_patch() diff --git a/weave/trace/patcher.py b/weave/trace/patcher.py index 1567c4e2bb9..c1d0d653ffa 100644 --- a/weave/trace/patcher.py +++ b/weave/trace/patcher.py @@ -17,6 +17,14 @@ def undo_patch(self) -> bool: raise NotImplementedError() +class NoOpPatcher(Patcher): + def attempt_patch(self) -> bool: + return True + + def undo_patch(self) -> bool: + return True + + class MultiPatcher(Patcher): def __init__(self, patchers: Sequence[Patcher]) -> None: self.patchers = patchers diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py index 563dcbdaed4..f51d42d5018 100644 --- a/weave/trace/weave_init.py +++ b/weave/trace/weave_init.py @@ -63,7 +63,9 @@ def get_entity_project_from_project_name(project_name: str) -> tuple[str, str]: def init_weave( - project_name: str, ensure_project_exists: bool = True + project_name: str, + ensure_project_exists: bool = True, + autopatch_settings: autopatch.AutopatchSettings | None = None, ) -> InitializedClient: global _current_inited_client if _current_inited_client is not None: @@ -120,7 +122,7 @@ def init_weave( # autopatching is only supported for the wandb client, because OpenAI calls are not # logged in local mode currently. When that's fixed, this autopatch call can be # moved to InitializedClient.__init__ - autopatch.autopatch() + autopatch.autopatch(autopatch_settings) username = get_username() try: From 16e47c3a8d804db1f7c8c80fe53b5b58082a6757 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Wed, 11 Dec 2024 16:19:24 -0600 Subject: [PATCH 47/60] chore(ui): update UUID dependency to v11 (latest) (#3208) --- weave-js/package.json | 3 +-- weave-js/yarn.lock | 15 +++++---------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/weave-js/package.json b/weave-js/package.json index cb57125143a..1f551021ed4 100644 --- a/weave-js/package.json +++ b/weave-js/package.json @@ -192,7 +192,6 @@ "@types/react-virtualized-auto-sizer": "^1.0.0", "@types/safe-json-stringify": "^1.1.2", "@types/styled-components": "^5.1.26", - "@types/uuid": "^9.0.1", "@types/wavesurfer.js": "^2.0.0", "@types/zen-observable": "^0.8.3", "@typescript-eslint/eslint-plugin": "5.35.1", @@ -237,7 +236,7 @@ "tslint-config-prettier": "^1.18.0", "tslint-plugin-prettier": "^2.3.0", "typescript": "4.7.4", - "uuid": "^9.0.0", + "uuid": "^11.0.3", "vite": "5.2.9", "vitest": "^1.6.0" }, diff --git a/weave-js/yarn.lock b/weave-js/yarn.lock index c7f9379e32a..6a5ec14e872 100644 --- a/weave-js/yarn.lock +++ b/weave-js/yarn.lock @@ -4776,11 +4776,6 @@ resolved "https://registry.yarnpkg.com/@types/unist/-/unist-2.0.7.tgz#5b06ad6894b236a1d2bd6b2f07850ca5c59cf4d6" integrity sha512-cputDpIbFgLUaGQn6Vqg3/YsJwxUwHLO13v3i5ouxT4lat0khip9AEWxtERujXV9wxIB1EyF97BSJFt6vpdI8g== -"@types/uuid@^9.0.1": - version "9.0.2" - resolved "https://registry.yarnpkg.com/@types/uuid/-/uuid-9.0.2.tgz#ede1d1b1e451548d44919dc226253e32a6952c4b" - integrity sha512-kNnC1GFBLuhImSnV7w4njQkUiJi0ZXUycu1rUaouPqiKlXkh77JKgdRnTAp1x5eBwcIwbtI+3otwzuIDEuDoxQ== - "@types/wavesurfer.js@^2.0.0": version "2.0.2" resolved "https://registry.yarnpkg.com/@types/wavesurfer.js/-/wavesurfer.js-2.0.2.tgz#b98a4d57ca24ee2028ae6dd5c2208b568bb73842" @@ -15032,6 +15027,11 @@ util-deprecate@^1.0.1, util-deprecate@^1.0.2, util-deprecate@~1.0.1: resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw== +uuid@^11.0.3: + version "11.0.3" + resolved "https://registry.yarnpkg.com/uuid/-/uuid-11.0.3.tgz#248451cac9d1a4a4128033e765d137e2b2c49a3d" + integrity sha512-d0z310fCWv5dJwnX1Y/MncBAqGMKEzlBb1AOf7z9K8ALnd0utBX/msg/fA0+sbyN1ihbMsLhrBlnl1ak7Wa0rg== + uuid@^2.0.2: version "2.0.3" resolved "https://registry.yarnpkg.com/uuid/-/uuid-2.0.3.tgz#67e2e863797215530dff318e5bf9dcebfd47b21a" @@ -15042,11 +15042,6 @@ uuid@^3.0.0, uuid@^3.4.0: resolved "https://registry.yarnpkg.com/uuid/-/uuid-3.4.0.tgz#b23e4358afa8a202fe7a100af1f5f883f02007ee" integrity sha512-HjSDRw6gZE5JMggctHBcjVak08+KEVhSIiDzFnT9S9aegmp85S/bReBVTb4QTFaRNptJ9kuYaNhnbNEOkbKb/A== -uuid@^9.0.0: - version "9.0.0" - resolved "https://registry.yarnpkg.com/uuid/-/uuid-9.0.0.tgz#592f550650024a38ceb0c562f2f6aa435761efb5" - integrity sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg== - uvu@^0.5.0: version "0.5.6" resolved "https://registry.yarnpkg.com/uvu/-/uvu-0.5.6.tgz#2754ca20bcb0bb59b64e9985e84d2e81058502df" From 444c04dcdeb965b531d7508f5e82af56cfcb00f2 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Wed, 11 Dec 2024 18:50:14 -0600 Subject: [PATCH 48/60] chore(ui): remove some unused code (#3157) --- .../PagePanelComponents/Home/Browse3.tsx | 138 +----------------- 1 file changed, 1 insertion(+), 137 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index 3517f4d3b9c..761fd536930 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -1,15 +1,5 @@ import {ApolloProvider} from '@apollo/client'; -import {Home} from '@mui/icons-material'; -import { - AppBar, - Box, - Breadcrumbs, - Drawer, - IconButton, - Link as MaterialLink, - Toolbar, - Typography, -} from '@mui/material'; +import {Box, Drawer} from '@mui/material'; import { GridColumnVisibilityModel, GridFilterModel, @@ -21,9 +11,7 @@ import {LicenseInfo} from '@mui/x-license'; import {makeGorillaApolloClient} from '@wandb/weave/apollo'; import {EVALUATE_OP_NAME_POST_PYDANTIC} from '@wandb/weave/components/PagePanelComponents/Home/Browse3/pages/common/heuristics'; import {opVersionKeyToRefUri} from '@wandb/weave/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/utilities'; -import _ from 'lodash'; import React, { - ComponentProps, FC, useCallback, useEffect, @@ -33,7 +21,6 @@ import React, { } from 'react'; import useMousetrap from 'react-hook-mousetrap'; import { - Link as RouterLink, Redirect, Route, Switch, @@ -199,7 +186,6 @@ export const Browse3: FC<{ `/${URL_BROWSE3}`, ]}> @@ -211,7 +197,6 @@ export const Browse3: FC<{ }; const Browse3Mounted: FC<{ - hideHeader?: boolean; headerOffset?: number; navigateAwayFromProject?: () => void; }> = props => { @@ -225,37 +210,6 @@ const Browse3Mounted: FC<{ overflow: 'auto', flexDirection: 'column', }}> - {!props.hideHeader && ( - theme.zIndex.drawer + 1, - height: '60px', - flex: '0 0 auto', - position: 'static', - }}> - - - theme.palette.getContrastText(theme.palette.primary.main), - '&:hover': { - color: theme => - theme.palette.getContrastText(theme.palette.primary.dark), - }, - marginRight: theme => theme.spacing(2), - }}> - - - - - - )} @@ -1050,20 +1004,6 @@ const ComparePageBinding = () => { return ; }; -const AppBarLink = (props: ComponentProps) => ( - theme.palette.getContrastText(theme.palette.primary.main), - '&:hover': { - color: theme => - theme.palette.getContrastText(theme.palette.primary.dark), - }, - }} - {...props} - component={RouterLink} - /> -); - const PlaygroundPageBinding = () => { const params = useParamsDecoded(); return ( @@ -1074,79 +1014,3 @@ const PlaygroundPageBinding = () => { /> ); }; - -const Browse3Breadcrumbs: FC = props => { - const params = useParamsDecoded(); - const query = useURLSearchParamsDict(); - const filePathParts = query.path?.split('/') ?? []; - const refFields = query.extra?.split('/') ?? []; - - return ( - - {params.entity && ( - - {params.entity} - - )} - {params.project && ( - - {params.project} - - )} - {params.tab && ( - - {params.tab} - - )} - {params.itemName && ( - - {params.itemName} - - )} - {params.version && ( - - {params.version} - - )} - {filePathParts.map((part, idx) => ( - - {part} - - ))} - {_.range(0, refFields.length, 2).map(idx => ( - - - theme.palette.getContrastText(theme.palette.primary.main), - }}> - {refFields[idx]} - - - {refFields[idx + 1]} - - - ))} - - ); -}; From 96d1d0d0f48cd571ae7b7737de2fd58d663588d9 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Wed, 11 Dec 2024 17:55:26 -0800 Subject: [PATCH 49/60] chore(ui): Create scorer drawer style + small annotation drawer style tweaks (#3186) --- .../StructuredFeedback/HumanAnnotation.tsx | 4 +- .../ScorersPage/AnnotationScorerForm.tsx | 9 +- .../pages/ScorersPage/FormComponents.tsx | 2 +- .../pages/ScorersPage/NewScorerDrawer.tsx | 25 +++- .../pages/ScorersPage/ZodSchemaForm.tsx | 133 +++++++++++------- 5 files changed, 111 insertions(+), 62 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx index 7facffe9556..2821c02affa 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx @@ -415,7 +415,7 @@ export const TextFeedbackColumn = ({ placeholder="" /> {maxLength && ( -
+
{`Maximum characters: ${maxLength}`}
)} @@ -603,7 +603,7 @@ export const NumericalTextField: React.FC = ({ errorState={error} /> {(min != null || max != null) && ( -
+
{isInteger ? 'Integer required. ' : ''} {min != null && `Min: ${min}`} {min != null && max != null && ', '} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx index 9acbdfe6c2f..a478437facb 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx @@ -1,5 +1,5 @@ import {Box} from '@material-ui/core'; -import React, {FC, useCallback, useState} from 'react'; +import React, {FC, useCallback, useEffect, useState} from 'react'; import {z} from 'zod'; import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; @@ -28,7 +28,7 @@ const AnnotationScorerFormSchema = z.object({ }), z.object({ type: z.literal('String'), - 'Max length': z.number().optional(), + 'Maximum length': z.number().optional(), }), z.object({ type: z.literal('Select'), @@ -45,6 +45,9 @@ export const AnnotationScorerForm: FC< ScorerFormProps> > = ({data, onDataChange}) => { const [config, setConfig] = useState(data ?? DEFAULT_STATE); + useEffect(() => { + setConfig(data ?? DEFAULT_STATE); + }, [data]); const [isValid, setIsValid] = useState(false); const handleConfigChange = useCallback( @@ -113,7 +116,7 @@ function convertTypeExtrasToJsonSchema( const typeSchema = obj.Type; const typeExtras: Record = {}; if (typeSchema.type === 'String') { - typeExtras.maxLength = typeSchema['Max length']; + typeExtras.maxLength = typeSchema['Maximum length']; } else if (typeSchema.type === 'Integer' || typeSchema.type === 'Number') { typeExtras.minimum = typeSchema.Minimum; typeExtras.maximum = typeSchema.Maximum; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx index 2716bfbfa81..250c896cfea 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx @@ -3,7 +3,7 @@ import {Select} from '@wandb/weave/components/Form/Select'; import {TextField} from '@wandb/weave/components/Form/TextField'; import React from 'react'; -export const GAP_BETWEEN_ITEMS_PX = 10; +export const GAP_BETWEEN_ITEMS_PX = 16; export const GAP_BETWEEN_LABEL_AND_FIELD_PX = 10; type AutocompleteWithLabelType