From a06b6ace52988bff45012d78a29c9adbbeb32249 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 6 Dec 2024 08:57:21 -0800 Subject: [PATCH 01/52] chore(weave): Improve large object viewer renders (#3102) * init * init * Larger refactor * Larger refactor * Larger refactor * Larger refactor * Larger refactor --- .../Browse3/pages/CallPage/CallDetails.tsx | 41 ++----- .../Browse3/pages/CallPage/ObjectViewer.tsx | 81 ++++++++----- .../pages/CallPage/ObjectViewerSection.tsx | 4 +- .../traceServerCachingClient.ts | 106 ++++++++++++++++++ .../wfReactInterface/traceServerClient.ts | 4 +- .../typeViews/CustomWeaveTypeDispatcher.tsx | 8 ++ .../PIL.Image.Image/PILImageImage.tsx | 5 +- 7 files changed, 183 insertions(+), 66 deletions(-) create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerCachingClient.ts diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx index 8a9c011b5de..26be9f0c5d0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx @@ -1,4 +1,3 @@ -import {Typography} from '@mui/material'; import Box from '@mui/material/Box'; import _ from 'lodash'; import React, {FC, useContext, useMemo} from 'react'; @@ -10,9 +9,7 @@ import {Button} from '../../../../../Button'; import {useWeaveflowRouteContext, WeaveflowPeekContext} from '../../context'; import {CustomWeaveTypeProjectContext} from '../../typeViews/CustomWeaveTypeDispatcher'; import {CallsTable} from '../CallsPage/CallsTable'; -import {KeyValueTable} from '../common/KeyValueTable'; -import {CallLink, opNiceName} from '../common/Links'; -import {CenteredAnimatedLoader} from '../common/Loader'; +import {CallLink} from '../common/Links'; import {useWFHooks} from '../wfReactInterface/context'; import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; import {ButtonOverlay} from './ButtonOverlay'; @@ -20,6 +17,8 @@ import {ExceptionDetails, getExceptionInfo} from './Exceptions'; import {ObjectViewerSection} from './ObjectViewerSection'; import {OpVersionText} from './OpVersionText'; +const HEADER_HEIGHT_BUFFER = 60; + const Heading = styled.div` color: ${MOON_800}; font-weight: 600; @@ -104,7 +103,7 @@ export const CallDetails: FC<{ columns ); - const {singularChildCalls, multipleChildCallOpRefs} = useMemo( + const {multipleChildCallOpRefs} = useMemo( () => callGrouping(!childCalls.loading ? childCalls.result ?? [] : []), [childCalls.loading, childCalls.result] ); @@ -133,6 +132,7 @@ export const CallDetails: FC<{ 0 ? HEADER_HEIGHT_BUFFER : 0 + }px)`, p: 2, }}> {'traceback' in excInfo ? ( @@ -201,7 +204,6 @@ export const CallDetails: FC<{ sx={{ flex: '0 0 auto', height: '500px', - maxHeight: '95%', p: 2, display: 'flex', flexDirection: 'column', @@ -240,33 +242,6 @@ export const CallDetails: FC<{ ); })} - {childCalls.loading && } - {/* Disabling display of singular children while we decide if we want them here. */} - {false && singularChildCalls.length > 0 && ( - - {multipleChildCallOpRefs.length === 0 ? ( - Child calls - ) : ( - Singular child calls - )} - {singularChildCalls.map(c => ( - - - - ))} - - )} ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx index d741db64771..e8f39ad2880 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx @@ -23,7 +23,11 @@ import {LoadingDots} from '../../../../../LoadingDots'; import {Browse2OpDefCode} from '../../../Browse2/Browse2OpDefCode'; import {isWeaveRef} from '../../filters/common'; import {StyledDataGrid} from '../../StyledDataGrid'; -import {isCustomWeaveTypePayload} from '../../typeViews/customWeaveType.types'; +import { + CustomWeaveTypePayload, + isCustomWeaveTypePayload, +} from '../../typeViews/customWeaveType.types'; +import {getCustomWeaveTypePreferredRowHeight} from '../../typeViews/CustomWeaveTypeDispatcher'; import { LIST_INDEX_EDGE_NAME, OBJECT_ATTR_EDGE_NAME, @@ -48,6 +52,10 @@ import { } from './traverse'; import {ValueView} from './ValueView'; +const DEFAULT_ROW_HEIGHT = 38; +const CODE_ROW_HEIGHT = 350; +const TABLE_ROW_HEIGHT = 450; + type Data = Record | any[]; type ObjectViewerProps = { @@ -440,6 +448,47 @@ export const ObjectViewer = ({ }); }, [apiRef, expandedIds, updateRowExpand]); + // Per https://mui.com/x/react-data-grid/row-height/#dynamic-row-height, always + // memoize the getRowHeight function. + const getRowHeight = useCallback((params: GridRowHeightParams) => { + const isNonRefString = + params.model.valueType === 'string' && !isWeaveRef(params.model.value); + const isArray = params.model.valueType === 'array'; + const isTableRef = + isWeaveRef(params.model.value) && + (parseRefMaybe(params.model.value) as any).weaveKind === 'table'; + const {isCode} = params.model; + const isCustomWeaveType = isCustomWeaveTypePayload(params.model.value); + if (!params.model.isLeaf) { + // This is a group header, so we want to use the default height + return DEFAULT_ROW_HEIGHT; + } else if (isNonRefString) { + // This is the only special case where we will allow for dynamic height. + // Since strings have special renders that take up different amounts of + // space, we will allow for dynamic height. + return 'auto'; + } else if (isCustomWeaveType) { + const type = (params.model.value as CustomWeaveTypePayload).weave_type + .type; + const preferredRowHeight = getCustomWeaveTypePreferredRowHeight(type); + if (preferredRowHeight) { + return preferredRowHeight; + } + return DEFAULT_ROW_HEIGHT; + } else if ((isArray && USE_TABLE_FOR_ARRAYS) || isTableRef) { + // Perfectly enough space for 1 page of data rows + return TABLE_ROW_HEIGHT; + } else if (isCode) { + // Probably will get negative feedback here since code that is < 20 lines + // will have some whitespace below the code. However, we absolutely need + // to have static height for all cells else the MUI data grid will jump around + // when cleaning up virtual rows. + return CODE_ROW_HEIGHT; + } else { + return DEFAULT_ROW_HEIGHT; + } + }, []); + // Finally, we memoize the inner data grid component. This is important to // reduce the number of re-renders when the data changes. const inner = useMemo(() => { @@ -473,31 +522,9 @@ export const ObjectViewer = ({ isGroupExpandedByDefault={node => { return expandedIds.includes(node.id); }} - autoHeight columnHeaderHeight={38} - getRowHeight={(params: GridRowHeightParams) => { - const isNonRefString = - params.model.valueType === 'string' && - !isWeaveRef(params.model.value); - const isArray = params.model.valueType === 'array'; - const isTableRef = - isWeaveRef(params.model.value) && - (parseRefMaybe(params.model.value) as any).weaveKind === 'table'; - const {isCode} = params.model; - const isCustomWeaveType = isCustomWeaveTypePayload( - params.model.value - ); - if ( - isNonRefString || - (isArray && USE_TABLE_FOR_ARRAYS) || - isTableRef || - isCode || - isCustomWeaveType - ) { - return 'auto'; - } - return 38; - }} + rowHeight={DEFAULT_ROW_HEIGHT} + getRowHeight={getRowHeight} hideFooter rowSelection={false} groupingColDef={groupingColDef} @@ -517,10 +544,10 @@ export const ObjectViewer = ({ }} /> ); - }, [apiRef, rows, columns, groupingColDef, expandedIds]); + }, [apiRef, rows, columns, getRowHeight, groupingColDef, expandedIds]); // Return the inner data grid wrapped in a div with overflow hidden. - return
{inner}
; + return
{inner}
; }; // Helper function to build the base ref for a given path. This function is used diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx index d95d9a841de..d02a2e881ca 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx @@ -185,7 +185,7 @@ const ObjectViewerSectionNonEmpty = ({ }, [data, isExpanded]); return ( - <> + {title} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx index b5f675633fe..94cd3c17644 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx @@ -115,11 +115,15 @@ export const PlaygroundChat = ({ }}> -
+
{state.traceCall && ( = ({ width: '100%', display: 'flex', justifyContent: 'space-between', - paddingBottom: '8px', }}> - - {props.headerContent} - + {props.headerContent && ( + + {props.headerContent} + + )} {(!props.hideTabsIfSingle || props.tabs.length > 1) && ( Date: Mon, 9 Dec 2024 12:44:29 -0500 Subject: [PATCH 14/52] chore(ui): Update metadata header styling (#3049) * Object metadata to Tailwind / refactor * Updated versions link and styling * Updated OpsVersionLink * Updated CallsLink to include next icon * Lint * Better flow for column sizes * Lint * No more colons --- .../Home/Browse3/pages/ObjectVersionPage.tsx | 83 ++++++++-------- .../Home/Browse3/pages/OpVersionPage.tsx | 98 ++++++++++++------- .../Home/Browse3/pages/common/Links.tsx | 22 +++-- 3 files changed, 119 insertions(+), 84 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx index 6108ed5407e..085587a64d8 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx @@ -211,47 +211,50 @@ const ObjectVersionPageInner: React.FC<{ } headerContent={ - - {objectName}{' '} - {objectVersions.loading ? ( - - ) : ( - <> - [ - +
+
+

Name

+
+ +
+ {objectName} + {objectVersions.loading ? ( + + ) : ( + + ({objectVersionCount} version + {objectVersionCount !== 1 ? 's' : ''}) + + )} + - ] - - )} - - ), - Version: <>{objectVersionIndex}, - ...(refExtra - ? { - Subpath: refExtra, - } - : {}), - // 'Type Version': ( - // - // ), - }} - /> +
+
+
+
+
+

Version

+

{objectVersionIndex}

+
+ {refExtra && ( +
+

Subpath

+

{refExtra}

+
+ )} +
+ } // menuItems={[ // { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx index 1a6e4afc577..36f4e44afc5 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx @@ -1,5 +1,6 @@ import React, {useMemo} from 'react'; +import {Icon} from '../../../../Icon'; import {LoadingDots} from '../../../../LoadingDots'; import {Tailwind} from '../../../../Tailwind'; import {NotFoundPanel} from '../NotFoundPanel'; @@ -13,7 +14,6 @@ import { import {CenteredAnimatedLoader} from './common/Loader'; import { ScrollableTabContent, - SimpleKeyValueTable, SimplePageLayoutWithHeader, } from './common/SimplePageLayout'; import {TabUseOp} from './TabUseOp'; @@ -75,49 +75,71 @@ const OpVersionPageInner: React.FC<{ - {opId}{' '} - {opVersions.loading ? ( - - ) : ( - <> - [ - - ] - - )} - - ), - Version: <>{versionIndex}, - Calls: - !callsStats.loading || opVersionCallCount > 0 ? ( - +
+
+

Name

+
+ + variant="secondary"> +
+ {opId} + {opVersions.loading ? ( + + ) : ( + + ({opVersionCount} version + {opVersionCount !== 1 ? 's' : ''}) + + )} + +
+
+
+
+
+

Version

+

{versionIndex}

+
+
+

Calls:

+ {!callsStats.loading || opVersionCallCount > 0 ? ( +
+ + +
) : ( - <> - ), - }} - /> +

-

+ )} +
+
+ } tabs={[ { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx index be78bb51367..4060735cc67 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx @@ -429,6 +429,7 @@ export const ObjectVersionsLink: React.FC<{ filter?: WFHighLevelObjectVersionFilter; neverPeek?: boolean; variant?: LinkVariant; + children?: React.ReactNode; }> = props => { const {peekingRouter, baseRouter} = useWeaveflowRouteContext(); const router = props.neverPeek ? baseRouter : peekingRouter; @@ -440,9 +441,13 @@ export const ObjectVersionsLink: React.FC<{ props.project, props.filter )}> - {props.versionCount} - {props.countIsLimited ? '+' : ''} version - {props.versionCount !== 1 ? 's' : ''} + {props.children ?? ( + <> + {props.versionCount} + {props.countIsLimited ? '+' : ''} version + {props.versionCount !== 1 ? 's' : ''} + + )} ); }; @@ -455,6 +460,7 @@ export const OpVersionsLink: React.FC<{ filter?: WFHighLevelOpVersionFilter; neverPeek?: boolean; variant?: LinkVariant; + children?: React.ReactNode; }> = props => { const {peekingRouter, baseRouter} = useWeaveflowRouteContext(); const router = props.neverPeek ? baseRouter : peekingRouter; @@ -462,9 +468,13 @@ export const OpVersionsLink: React.FC<{ - {props.versionCount} - {props.countIsLimited ? '+' : ''} version - {props.versionCount !== 1 ? 's' : ''} + {props.children ?? ( + <> + {props.versionCount} + {props.countIsLimited ? '+' : ''} version + {props.versionCount !== 1 ? 's' : ''} + + )} ); }; From 7320ca9a1fcc51808103403d61923e00361a02d9 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Mon, 9 Dec 2024 10:35:15 -0800 Subject: [PATCH 15/52] chore(ui): disable refresh while refreshing (#3179) --- .../Home/Browse3/pages/CallsPage/CallsTable.tsx | 5 ++++- .../Home/Browse3/pages/CallsPage/CallsTableButtons.tsx | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx index 18130d64341..3632664aaf0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx @@ -744,7 +744,10 @@ export const CallsTable: FC<{ }} filterListItems={ - calls.refetch()} /> + calls.refetch()} + disabled={callsLoading} + /> {!hideOpSelector && ( void; -}> = ({onClick}) => { + disabled?: boolean; +}> = ({onClick, disabled}) => { return ( From 0cfe058b709b77c96b901a0975f113e4d42d85a3 Mon Sep 17 00:00:00 2001 From: Connie Lee Date: Mon, 9 Dec 2024 11:39:23 -0800 Subject: [PATCH 16/52] style(app): Update tailwind shadow configs (#3180) --- weave-js/tailwind.config.cjs | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/weave-js/tailwind.config.cjs b/weave-js/tailwind.config.cjs index b234cc3cf1a..d8a580c928c 100644 --- a/weave-js/tailwind.config.cjs +++ b/weave-js/tailwind.config.cjs @@ -13,8 +13,13 @@ module.exports = { */ boxShadow: { none: 'none', - md: '0px 12px 24px 0px #15181F29', - lg: '0px 24px 48px 0px #15181F29', + flat: '0px 4px 8px 0px #0D0F120a', // oblivion 4% + medium: '0px 12px 24px 0px #0D0F1229', // oblivion 16% + deep: '0px 24px 48px 0px #0D0F123d', // oblivion 24% + + // deprecated shadow configs + md: '0px 12px 24px 0px #15181F29', // use shadow-medium instead + lg: '0px 24px 48px 0px #15181F29', // use shadow-deep instead }, spacing: { 0: '0rem', @@ -189,17 +194,17 @@ module.exports = { }, extend: { animation: { - 'wave': 'wave 3s linear infinite' + wave: 'wave 3s linear infinite', }, keyframes: { - "wave": { - "0%, 30%, 100%": { - transform: "initial" + wave: { + '0%, 30%, 100%': { + transform: 'initial', }, - "15%": { - transform: "translateY(-10px)" - } - } + '15%': { + transform: 'translateY(-10px)', + }, + }, }, opacity: { 35: '.35', @@ -221,6 +226,6 @@ module.exports = { in their parent hierarchy */ important: '.tw-style', experimental: { - optimizeUniversalDefaults: true + optimizeUniversalDefaults: true, }, }; From e9b73f7c44479cc783ddc12d59518b5e6e40f1b4 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 9 Dec 2024 12:13:06 -0800 Subject: [PATCH 17/52] feat(weave): Add AzureOpenAI support for Scorers (#3171) * init * init --- weave/scorers/llm_utils.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/weave/scorers/llm_utils.py b/weave/scorers/llm_utils.py index 8e2672b8ee7..68ae2ccb366 100644 --- a/weave/scorers/llm_utils.py +++ b/weave/scorers/llm_utils.py @@ -23,10 +23,17 @@ from google.generativeai import GenerativeModel from instructor.patch import InstructorChatCompletionCreate from mistralai import Mistral - from openai import AsyncOpenAI, OpenAI + from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI _LLM_CLIENTS = Union[ - OpenAI, AsyncOpenAI, Anthropic, AsyncAnthropic, Mistral, GenerativeModel + OpenAI, + AsyncOpenAI, + AzureOpenAI, + AsyncAzureOpenAI, + Anthropic, + AsyncAnthropic, + Mistral, + GenerativeModel, ] else: _LLM_CLIENTS = object @@ -34,6 +41,8 @@ _LLM_CLIENTS_NAMES = ( "OpenAI", "AsyncOpenAI", + "AzureOpenAI", + "AsyncAzureOpenAI", "Anthropic", "AsyncAnthropic", "Mistral", From 2351ecf4f867a7834262f9b68ed026ba77723b24 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Mon, 9 Dec 2024 15:15:14 -0600 Subject: [PATCH 18/52] chore(weave): add sun and moon icons (#3182) --- weave-js/src/assets/icons/icon-moon.svg | 3 +++ weave-js/src/assets/icons/icon-not-visible.svg | 2 +- weave-js/src/assets/icons/icon-pin-to-right.svg | 2 +- weave-js/src/assets/icons/icon-sun.svg | 11 +++++++++++ weave-js/src/components/Icon/Icon.tsx | 10 ++++++++++ weave-js/src/components/Icon/index.ts | 2 ++ weave-js/src/components/Icon/types.ts | 2 ++ 7 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 weave-js/src/assets/icons/icon-moon.svg create mode 100644 weave-js/src/assets/icons/icon-sun.svg diff --git a/weave-js/src/assets/icons/icon-moon.svg b/weave-js/src/assets/icons/icon-moon.svg new file mode 100644 index 00000000000..e448eab96c3 --- /dev/null +++ b/weave-js/src/assets/icons/icon-moon.svg @@ -0,0 +1,3 @@ + + + diff --git a/weave-js/src/assets/icons/icon-not-visible.svg b/weave-js/src/assets/icons/icon-not-visible.svg index 766810c7811..b2782d987b9 100644 --- a/weave-js/src/assets/icons/icon-not-visible.svg +++ b/weave-js/src/assets/icons/icon-not-visible.svg @@ -2,4 +2,4 @@ - \ No newline at end of file + diff --git a/weave-js/src/assets/icons/icon-pin-to-right.svg b/weave-js/src/assets/icons/icon-pin-to-right.svg index 1ae05ea52ae..46a9c0bf114 100644 --- a/weave-js/src/assets/icons/icon-pin-to-right.svg +++ b/weave-js/src/assets/icons/icon-pin-to-right.svg @@ -2,4 +2,4 @@ - \ No newline at end of file + diff --git a/weave-js/src/assets/icons/icon-sun.svg b/weave-js/src/assets/icons/icon-sun.svg new file mode 100644 index 00000000000..bb0c57891b0 --- /dev/null +++ b/weave-js/src/assets/icons/icon-sun.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/weave-js/src/components/Icon/Icon.tsx b/weave-js/src/components/Icon/Icon.tsx index 948c8a456b4..9c020b1bab9 100644 --- a/weave-js/src/components/Icon/Icon.tsx +++ b/weave-js/src/components/Icon/Icon.tsx @@ -139,6 +139,7 @@ import {ReactComponent as ImportMinimizeMode} from '../../assets/icons/icon-mini import {ReactComponent as ImportModel} from '../../assets/icons/icon-model.svg'; import {ReactComponent as ImportModelOnDark} from '../../assets/icons/icon-model-on-dark.svg'; import {ReactComponent as ImportMolecule} from '../../assets/icons/icon-molecule.svg'; +import {ReactComponent as ImportMoon} from '../../assets/icons/icon-moon.svg'; import {ReactComponent as ImportMusicAudio} from '../../assets/icons/icon-music-audio.svg'; import {ReactComponent as ImportNewSectionAbove} from '../../assets/icons/icon-new-section-above.svg'; import {ReactComponent as ImportNewSectionBelow} from '../../assets/icons/icon-new-section-below.svg'; @@ -216,6 +217,7 @@ import {ReactComponent as ImportStar} from '../../assets/icons/icon-star.svg'; import {ReactComponent as ImportStarFilled} from '../../assets/icons/icon-star-filled.svg'; import {ReactComponent as ImportStop} from '../../assets/icons/icon-stop.svg'; import {ReactComponent as ImportStopped} from '../../assets/icons/icon-stopped.svg'; +import {ReactComponent as ImportSun} from '../../assets/icons/icon-sun.svg'; import {ReactComponent as ImportSwap} from '../../assets/icons/icon-swap.svg'; import {ReactComponent as ImportSweepBayes} from '../../assets/icons/icon-sweep-bayes.svg'; import {ReactComponent as ImportSweepGrid} from '../../assets/icons/icon-sweep-grid.svg'; @@ -695,6 +697,9 @@ export const IconModelOnDark = (props: SVGIconProps) => ( export const IconMolecule = (props: SVGIconProps) => ( ); +export const IconMoon = (props: SVGIconProps) => ( + +); export const IconMusicAudio = (props: SVGIconProps) => ( ); @@ -926,6 +931,9 @@ export const IconStop = (props: SVGIconProps) => ( export const IconStopped = (props: SVGIconProps) => ( ); +export const IconSun = (props: SVGIconProps) => ( + +); export const IconSwap = (props: SVGIconProps) => ( ); @@ -1211,6 +1219,7 @@ const ICON_NAME_TO_ICON: Record = { model: IconModel, 'model-on-dark': IconModelOnDark, molecule: IconMolecule, + moon: IconMoon, 'music-audio': IconMusicAudio, 'new-section-above': IconNewSectionAbove, 'new-section-below': IconNewSectionBelow, @@ -1288,6 +1297,7 @@ const ICON_NAME_TO_ICON: Record = { 'star-filled': IconStarFilled, stop: IconStop, stopped: IconStopped, + sun: IconSun, swap: IconSwap, 'sweep-bayes': IconSweepBayes, 'sweep-grid': IconSweepGrid, diff --git a/weave-js/src/components/Icon/index.ts b/weave-js/src/components/Icon/index.ts index 81839a71019..85ea5332649 100644 --- a/weave-js/src/components/Icon/index.ts +++ b/weave-js/src/components/Icon/index.ts @@ -139,6 +139,7 @@ export { IconModel, IconModelOnDark, IconMolecule, + IconMoon, IconMusicAudio, IconNewSectionAbove, IconNewSectionBelow, @@ -216,6 +217,7 @@ export { IconStarFilled, IconStop, IconStopped, + IconSun, IconSwap, IconSweepBayes, IconSweepGrid, diff --git a/weave-js/src/components/Icon/types.ts b/weave-js/src/components/Icon/types.ts index 55f46c52833..87a1207bc85 100644 --- a/weave-js/src/components/Icon/types.ts +++ b/weave-js/src/components/Icon/types.ts @@ -138,6 +138,7 @@ export const IconNames = { Model: 'model', ModelOnDark: 'model-on-dark', Molecule: 'molecule', + Moon: 'moon', MusicAudio: 'music-audio', NewSectionAbove: 'new-section-above', NewSectionBelow: 'new-section-below', @@ -215,6 +216,7 @@ export const IconNames = { StarFilled: 'star-filled', Stop: 'stop', Stopped: 'stopped', + Sun: 'sun', Swap: 'swap', SweepBayes: 'sweep-bayes', SweepGrid: 'sweep-grid', From 8d0f3f7a6d84108fe52d335251d4080078dff0f0 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Mon, 9 Dec 2024 16:08:01 -0600 Subject: [PATCH 19/52] chore(ui): fix typo (#3183) --- .../Home/Browse3/pages/CallPage/CallPage.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index 3e4a74a4885..484b038c193 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -312,7 +312,7 @@ const CallPageInnerVertical: FC<{ const {rowIdsConfigured} = useContext(TableRowSelectionContext); const {isPeeking} = useContext(WeaveflowPeekContext); - const showPaginationContols = isPeeking && rowIdsConfigured; + const showPaginationControls = isPeeking && rowIdsConfigured; const callTabs = useCallTabs(currentCall); @@ -330,10 +330,10 @@ const CallPageInnerVertical: FC<{ justifyContent: 'space-between', alignItems: 'center', }}> - {showPaginationContols && ( + {showPaginationControls && ( )} - + +
+ +
+
+ + + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx index 9706ac09567..7942aea195e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx @@ -2,50 +2,54 @@ import * as Plotly from 'plotly.js'; import React, {useEffect, useMemo, useRef} from 'react'; import {PLOT_GRID_COLOR} from '../../ecpConstants'; -import {RadarPlotData} from './PlotlyRadarPlot'; export const PlotlyBarPlot: React.FC<{ height: number; - data: RadarPlotData; + yRange: [number, number]; + plotlyData: Plotly.Data; }> = props => { const divRef = useRef(null); - const plotlyData: Plotly.Data[] = useMemo(() => { - return Object.keys(props.data).map((key, i) => { - const {metrics, name, color} = props.data[key]; - return { - type: 'bar', - y: Object.values(metrics), - x: Object.keys(metrics), - name, - marker: {color}, - }; - }); - }, [props.data]); - const plotlyLayout: Partial = useMemo(() => { return { - height: props.height - 40, + height: props.height - 30, showlegend: false, margin: { - l: 0, + l: 20, r: 0, b: 20, - t: 0, - pad: 0, + t: 26, }, + bargap: 0.1, xaxis: { automargin: true, fixedrange: true, gridcolor: PLOT_GRID_COLOR, linecolor: PLOT_GRID_COLOR, + showticklabels: false, }, yaxis: { fixedrange: true, + range: props.yRange, gridcolor: PLOT_GRID_COLOR, linecolor: PLOT_GRID_COLOR, + showticklabels: true, + tickfont: { + size: 10, + }, + }, + title: { + multiline: true, + text: props.plotlyData.name ?? '', + font: {size: 12}, + xref: 'paper', + x: 0.5, + y: 1, + yanchor: 'top', + pad: {t: 2}, }, }; - }, [props.height]); + }, [props.height, props.plotlyData, props.yRange]); + const plotlyConfig = useMemo(() => { return { displayModeBar: false, @@ -57,11 +61,11 @@ export const PlotlyBarPlot: React.FC<{ useEffect(() => { Plotly.newPlot( divRef.current as any, - plotlyData, + [props.plotlyData], plotlyLayout, plotlyConfig ); - }, [plotlyConfig, plotlyData, plotlyLayout]); + }, [plotlyConfig, props.plotlyData, plotlyLayout]); return
; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx index d459d1354f1..47d0fa3f10c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx @@ -31,13 +31,13 @@ export const PlotlyRadarPlot: React.FC<{ }, [props.data]); const plotlyLayout: Partial = useMemo(() => { return { - height: props.height, + height: props.height - 40, showlegend: false, margin: { - l: 60, - r: 0, + l: 20, + r: 20, b: 30, - t: 30, + t: 20, pad: 0, }, polar: { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx index 5bfaa8fcb04..c23ffcce04d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx @@ -1,15 +1,6 @@ import {Box} from '@material-ui/core'; -import {Popover} from '@mui/material'; -import {Switch} from '@wandb/weave/components'; import {Button} from '@wandb/weave/components/Button'; -import { - DraggableGrow, - DraggableHandle, -} from '@wandb/weave/components/DraggablePopups'; -import {TextField} from '@wandb/weave/components/Form/TextField'; import {Tailwind} from '@wandb/weave/components/Tailwind'; -import {maybePluralize} from '@wandb/weave/core/util/string'; -import classNames from 'classnames'; import React, {useEffect, useMemo, useRef, useState} from 'react'; import {buildCompositeMetricsMap} from '../../compositeMetricsUtil'; @@ -27,6 +18,7 @@ import { resolveSummaryMetricValueForEvaluateCall, } from '../../ecpUtil'; import {HorizontalBox, VerticalBox} from '../../Layout'; +import {MetricsSelector} from './MetricsSelector'; import {PlotlyBarPlot} from './PlotlyBarPlot'; import {PlotlyRadarPlot, RadarPlotData} from './PlotlyRadarPlot'; @@ -36,15 +28,12 @@ import {PlotlyRadarPlot, RadarPlotData} from './PlotlyRadarPlot'; export const SummaryPlots: React.FC<{ state: EvaluationComparisonState; setSelectedMetrics: (newModel: Record) => void; -}> = props => { - const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics( - props.state - ); - const {selectedMetrics} = props.state; - const setSelectedMetrics = props.setSelectedMetrics; +}> = ({state, setSelectedMetrics}) => { + const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics(state); + const {selectedMetrics} = state; + // Initialize selectedMetrics if null useEffect(() => { - // If selectedMetrics is null, we should show all metrics if (selectedMetrics == null) { setSelectedMetrics( Object.fromEntries(Array.from(allMetricNames).map(m => [m, true])) @@ -52,10 +41,184 @@ export const SummaryPlots: React.FC<{ } }, [selectedMetrics, setSelectedMetrics, allMetricNames]); - // filter down the plotlyRadarData to only include the selected metrics, after - // computation, to allow quick addition/removal of metrics - const filteredPlotlyRadarData = useMemo(() => { - const filteredData: RadarPlotData = {}; + const filteredData = useFilteredData(radarData, selectedMetrics); + const normalizedRadarData = normalizeDataForRadarPlot(filteredData); + const barPlotData = useBarPlotData(filteredData); + + const { + containerRef, + isInitialRender, + plotsPerPage, + currentPage, + setCurrentPage, + } = useContainerDimensions(); + + const {plotsToShow, totalPlots, startIndex, endIndex, totalPages} = + usePaginatedPlots( + normalizedRadarData, + barPlotData, + plotsPerPage, + currentPage + ); + + // Render placeholder during initial render + if (isInitialRender) { + return
; + } + + return ( + + +
+ {plotsToShow} +
+ setCurrentPage(prev => Math.max(prev - 1, 0))} + onNextPage={() => + setCurrentPage(prev => Math.min(prev + 1, totalPages - 1)) + } + /> +
+ ); +}; + +const SectionHeader: React.FC<{ + selectedMetrics: Record | undefined; + setSelectedMetrics: (newModel: Record) => void; + allMetrics: string[]; +}> = ({selectedMetrics, setSelectedMetrics, allMetrics}) => ( + + + Summary Metrics + + +
+
Configure displayed metrics
+ +
+
+
+); + +const RadarPlotBox: React.FC<{data: RadarPlotData}> = ({data}) => ( + + + +); + +const BarPlotBox: React.FC<{ + plot: {plotlyData: Plotly.Data; yRange: [number, number]}; +}> = ({plot}) => ( + + + +); + +const PaginationControls: React.FC<{ + currentPage: number; + totalPages: number; + startIndex: number; + endIndex: number; + totalPlots: number; + onPrevPage: () => void; + onNextPage: () => void; +}> = ({ + currentPage, + totalPages, + startIndex, + endIndex, + totalPlots, + onPrevPage, + onNextPage, +}) => ( + + + +
+
+
+
+
+); + +const useFilteredData = ( + radarData: RadarPlotData, + selectedMetrics: Record | undefined +) => + useMemo(() => { + const data: RadarPlotData = {}; for (const [callId, metricBin] of Object.entries(radarData)) { const metrics: {[metric: string]: number} = {}; for (const [metric, value] of Object.entries(metricBin.metrics)) { @@ -64,253 +227,195 @@ export const SummaryPlots: React.FC<{ } } if (Object.keys(metrics).length > 0) { - filteredData[callId] = { + data[callId] = { metrics, name: metricBin.name, color: metricBin.color, }; } } - return filteredData; + return data; }, [radarData, selectedMetrics]); - return ( - - - - Summary Metrics - - -
-
Configure displayed metrics
- -
-
-
- - - - - - - - -
- ); -}; +function getMetricValuesFromRadarData(radarData: RadarPlotData): { + [metric: string]: number[]; +} { + const metricValues: {[metric: string]: number[]} = {}; + // Gather all values for each metric + Object.values(radarData).forEach(callData => { + Object.entries(callData.metrics).forEach(([metric, value]) => { + if (!metricValues[metric]) { + metricValues[metric] = []; + } + metricValues[metric].push(value); + }); + }); + return metricValues; +} -const MetricsSelector: React.FC<{ - setSelectedMetrics: (newModel: Record) => void; - selectedMetrics: Record | undefined; - allMetrics: string[]; -}> = ({setSelectedMetrics, selectedMetrics, allMetrics}) => { - const [search, setSearch] = useState(''); - - const ref = useRef(null); - const [anchorEl, setAnchorEl] = useState(null); - const onClick = (event: React.MouseEvent) => { - setAnchorEl(anchorEl ? null : ref.current); - setSearch(''); +function getMetricMinsFromRadarData(radarData: RadarPlotData): { + [metric: string]: number; +} { + const metricValues = getMetricValuesFromRadarData(radarData); + const metricMins: {[metric: string]: number} = {}; + Object.entries(metricValues).forEach(([metric, values]) => { + metricMins[metric] = Math.min(...values); + }); + return metricMins; +} + +function normalizeDataForRadarPlot(radarData: RadarPlotData): RadarPlotData { + const metricMins = getMetricMinsFromRadarData(radarData); + + const normalizedData: RadarPlotData = {}; + Object.entries(radarData).forEach(([callId, callData]) => { + normalizedData[callId] = { + name: callData.name, + color: callData.color, + metrics: {}, + }; + + Object.entries(callData.metrics).forEach(([metric, value]) => { + const min = metricMins[metric]; + // Only shift values if there are negative values + const normalizedValue = min < 0 ? value - min : value; + normalizedData[callId].metrics[metric] = normalizedValue; + }); + }); + + return normalizedData; +} + +const useBarPlotData = (filteredData: RadarPlotData) => + useMemo(() => { + const metrics: { + [metric: string]: { + callIds: string[]; + values: number[]; + name: string; + colors: string[]; + }; + } = {}; + + // Reorganize data by metric instead of by call + for (const [callId, metricBin] of Object.entries(filteredData)) { + for (const [metric, value] of Object.entries(metricBin.metrics)) { + if (!metrics[metric]) { + metrics[metric] = {callIds: [], values: [], name: metric, colors: []}; + } + metrics[metric].callIds.push(callId); + metrics[metric].values.push(value); + metrics[metric].colors.push(metricBin.color); + } + } + + // Convert metrics object to Plotly data format + return Object.entries(metrics).map(([metric, metricBin]) => { + const maxY = Math.max(...metricBin.values) * 1.1; + const minY = Math.min(...metricBin.values, 0); + const plotlyData: Plotly.Data = { + type: 'bar', + y: metricBin.values, + x: metricBin.callIds, + text: metricBin.values.map(value => value.toFixed(3)), + textposition: 'outside', + textfont: {size: 14, color: 'black'}, + name: metric, + marker: {color: metricBin.colors}, + }; + return {plotlyData, yRange: [minY, maxY] as [number, number]}; + }); + }, [filteredData]); + +const useContainerDimensions = () => { + const containerRef = useRef(null); + const [containerWidth, setContainerWidth] = useState(0); + const [isInitialRender, setIsInitialRender] = useState(true); + const [currentPage, setCurrentPage] = useState(0); + + useEffect(() => { + const updateWidth = () => { + if (containerRef.current) { + setContainerWidth(containerRef.current.offsetWidth); + } + }; + + updateWidth(); + setIsInitialRender(false); + + window.addEventListener('resize', updateWidth); + return () => window.removeEventListener('resize', updateWidth); + }, []); + + const plotsPerPage = useMemo(() => { + return Math.max(1, Math.floor(containerWidth / PLOT_HEIGHT)); + }, [containerWidth]); + + return { + containerRef, + isInitialRender, + plotsPerPage, + currentPage, + setCurrentPage, }; - const open = Boolean(anchorEl); - const id = open ? 'simple-popper' : undefined; +}; - const filteredCols = search - ? allMetrics.filter(col => col.toLowerCase().includes(search.toLowerCase())) - : allMetrics; +const usePaginatedPlots = ( + filteredData: RadarPlotData, + barPlotData: Array<{plotlyData: Plotly.Data; yRange: [number, number]}>, + plotsPerPage: number, + currentPage: number +) => { + const radarPlotWidth = 2; + const totalBarPlots = barPlotData.length; + const totalPlotWidth = radarPlotWidth + totalBarPlots; + const totalPages = Math.ceil(totalPlotWidth / plotsPerPage); - const shownMetrics = Object.values(selectedMetrics ?? {}).filter(Boolean); + const plotsToShow = useMemo(() => { + // First page always shows radar plot + if (currentPage === 0) { + const availableSpace = plotsPerPage - radarPlotWidth; + return [ + , + ...barPlotData + .slice(0, availableSpace) + .map((plot, index) => ( + + )), + ]; + } else { + // Subsequent pages show only bar plots + const startIdx = + (currentPage - 1) * plotsPerPage + (plotsPerPage - radarPlotWidth); + const endIdx = startIdx + plotsPerPage; + return barPlotData + .slice(startIdx, endIdx) + .map((plot, index) => ( + + )); + } + }, [currentPage, plotsPerPage, filteredData, barPlotData]); - const numHidden = allMetrics.length - shownMetrics.length; - const buttonSuffix = search ? `(${filteredCols.length})` : 'all'; + // Calculate pagination details + const totalPlots = barPlotData.length + 1; // +1 for the radar plot + const startIndex = + currentPage === 0 ? 1 : Math.min(plotsPerPage + 1, totalPlots); + const endIndex = + currentPage === 0 + ? Math.min(plotsToShow.length, totalPlots) + : Math.min(startIndex + plotsToShow.length - 1, totalPlots); - return ( - <> - - -
- -
-
- - - - ); + return {plotsToShow, totalPlots, startIndex, endIndex, totalPages}; }; -const normalizeValues = (values: Array): number[] => { +function normalizeValues(values: Array): number[] { // find the max value // find the power of 2 that is greater than the max value // divide all values by that power of 2 const maxVal = Math.max(...(values.filter(v => v !== undefined) as number[])); const maxPower = Math.ceil(Math.log2(maxVal)); return values.map(val => (val ? val / 2 ** maxPower : 0)); -}; +} const useNormalizedPlotDataFromMetrics = ( state: EvaluationComparisonState From a7170b4c1e6dbca9b279865b588d4edfee34468f Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:00:13 -0600 Subject: [PATCH 21/52] feat(ui): Support Anthropic calls in Chat View (#3187) --- .../Home/Browse3/pages/ChatView/hooks.ts | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts index 38b5c820195..33ced58ec49 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts @@ -246,6 +246,63 @@ export const isTraceCallChatFormatGemini = (call: TraceCallSchema): boolean => { ); }; +export const isAnthropicContentBlock = (block: any): boolean => { + if (!_.isPlainObject(block)) { + return false; + } + // TODO: Are there other types? + if (block.type !== 'text') { + return false; + } + if (!hasStringProp(block, 'text')) { + return false; + } + return true; +}; + +export const isAnthropicCompletionFormat = (output: any): boolean => { + if (output !== null) { + // TODO: Could have additional checks here on things like usage + if ( + _.isPlainObject(output) && + output.type === 'message' && + output.role === 'assistant' && + hasStringProp(output, 'model') && + _.isArray(output.content) && + output.content.every((c: any) => isAnthropicContentBlock(c)) + ) { + return true; + } + return false; + } + return true; +}; + +type AnthropicContentBlock = { + type: 'text'; + text: string; +}; + +export const anthropicContentBlocksToChoices = ( + blocks: AnthropicContentBlock[], + stopReason: string +): Choice[] => { + const choices: Choice[] = []; + for (let i = 0; i < blocks.length; i++) { + const block = blocks[i]; + choices.push({ + index: i, + message: { + role: 'assistant', + content: block.text, + }, + // TODO: What is correct way to map this? + finish_reason: stopReason, + }); + } + return choices; +}; + export const isTraceCallChatFormatOpenAI = (call: TraceCallSchema): boolean => { if (!('messages' in call.inputs)) { return false; @@ -336,6 +393,19 @@ export const normalizeChatRequest = (request: any): ChatRequest => { ], }; } + // Anthropic has system message as a top-level request field + if (hasStringProp(request, 'system')) { + return { + ...request, + messages: [ + { + role: 'system', + content: request.system, + }, + ...request.messages, + ], + }; + } return request as ChatRequest; }; @@ -360,6 +430,24 @@ export const normalizeChatCompletion = ( }, }; } + if (isAnthropicCompletionFormat(completion)) { + return { + id: completion.id, + choices: anthropicContentBlocksToChoices( + completion.content, + completion.stop_reason + ), + created: 0, + model: completion.model, + system_fingerprint: '', + usage: { + prompt_tokens: completion.usage.input_tokens, + completion_tokens: completion.usage.output_tokens, + total_tokens: + completion.usage.input_tokens + completion.usage.output_tokens, + }, + }; + } return completion as ChatCompletion; }; From 935f68cbfdd55a2f65a476b201d2227167416097 Mon Sep 17 00:00:00 2001 From: Connie Lee Date: Tue, 10 Dec 2024 09:28:08 -0800 Subject: [PATCH 22/52] feat(ui): Add ExpandingPill component (#3184) --- weave-js/src/components/Tag/Pill.tsx | 39 ++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/weave-js/src/components/Tag/Pill.tsx b/weave-js/src/components/Tag/Pill.tsx index 6f734f9cbfc..ed958a0c543 100644 --- a/weave-js/src/components/Tag/Pill.tsx +++ b/weave-js/src/components/Tag/Pill.tsx @@ -59,3 +59,42 @@ export const IconOnlyPill: FC = ({ ); }; + +export type ExpandingPillProps = { + className?: string; + color?: TagColorName; + icon: IconName; + label: string; +}; +export const ExpandingPill = ({ + className, + color, + icon, + label, +}: ExpandingPillProps) => { + const classes = useTagClasses({color, isInteractive: true}); + return ( + +
+ + + {label} + +
+
+ ); +}; From a14edbe6f6c37f3e9094b50ffdd2adc5bd3ecb65 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:52:04 -0600 Subject: [PATCH 23/52] chore(weave): ui_url error gives a hint about the problem (#3190) --- weave/trace/weave_client.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 86ba7b8653e..0eca3fcbedb 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -239,7 +239,9 @@ def func_name(self) -> str: @property def feedback(self) -> RefFeedbackQuery: if not self.id: - raise ValueError("Can't get feedback for call without ID") + raise ValueError( + "Can't get feedback for call without ID, was `weave.init` called?" + ) if self._feedback is None: try: @@ -253,7 +255,9 @@ def feedback(self) -> RefFeedbackQuery: @property def ui_url(self) -> str: if not self.id: - raise ValueError("Can't get URL for call without ID") + raise ValueError( + "Can't get URL for call without ID, was `weave.init` called?" + ) try: entity, project = self.project_id.split("/") @@ -265,7 +269,9 @@ def ui_url(self) -> str: def ref(self) -> CallRef: entity, project = self.project_id.split("/") if not self.id: - raise ValueError("Can't get ref for call without ID") + raise ValueError( + "Can't get ref for call without ID, was `weave.init` called?" + ) return CallRef(entity, project, self.id) @@ -273,7 +279,9 @@ def ref(self) -> CallRef: def children(self) -> CallsIter: client = weave_client_context.require_weave_client() if not self.id: - raise ValueError("Can't get children of call without ID") + raise ValueError( + "Can't get children of call without ID, was `weave.init` called?" + ) client = weave_client_context.require_weave_client() return CallsIter( From c19fe6bf246c058a3cbaa11bd544eb2e2fbbd24c Mon Sep 17 00:00:00 2001 From: brianlund-wandb Date: Tue, 10 Dec 2024 10:36:51 -0800 Subject: [PATCH 24/52] fix(app): render bounding boxes on media panel without points (#3158) * WB-16623 coercing empty point to default, targeting camera on bounding box when point cloud empty * code style cleanup --- .../src/common/util/SdkPointCloudToBabylon.test.ts | 14 ++++++++++++++ weave-js/src/common/util/SdkPointCloudToBabylon.ts | 2 +- weave-js/src/common/util/render_babylon.ts | 13 +++++++++++-- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts b/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts index 95c7639dc4a..df5eca57b46 100644 --- a/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts +++ b/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts @@ -4,6 +4,7 @@ import { DEFAULT_POINT_COLOR, getFilteringOptionsForPointCloud, getVertexCompatiblePositionsAndColors, + loadPointCloud, MAX_BOUNDING_BOX_LABELS_FOR_DISPLAY, MaxAlphaValue, } from './SdkPointCloudToBabylon'; @@ -174,3 +175,16 @@ describe('getFilteringOptionsForPointCloud', () => { expect(newClassIdToLabel.get(49)).toEqual('label49'); }); }); +describe('loadPointCloud', () => { + it('appropriate defaults set when loading point cloud from file', () => { + const fileContents = JSON.stringify({ + boxes: [], + points: [[]], + type: 'lidar/beta', + vectors: [], + }); + const babylonPointCloud = loadPointCloud(fileContents); + expect(babylonPointCloud.points).toHaveLength(1); + expect(babylonPointCloud.points[0].position).toEqual([0, 0, 0]); + }); +}); diff --git a/weave-js/src/common/util/SdkPointCloudToBabylon.ts b/weave-js/src/common/util/SdkPointCloudToBabylon.ts index 274e1676be4..d52682743ee 100644 --- a/weave-js/src/common/util/SdkPointCloudToBabylon.ts +++ b/weave-js/src/common/util/SdkPointCloudToBabylon.ts @@ -160,7 +160,7 @@ export const handlePoints = (object3D: Object3DScene): ScenePoint[] => { // Draw Points return truncatedPoints.map(point => { const [x, y, z, r, g, b] = point; - const position: Position = [x, y, z]; + const position: Position = [x ?? 0, y ?? 0, z ?? 0]; const category = r; if (r !== undefined && g !== undefined && b !== undefined) { diff --git a/weave-js/src/common/util/render_babylon.ts b/weave-js/src/common/util/render_babylon.ts index 10aee3f6c51..ebd213c2677 100644 --- a/weave-js/src/common/util/render_babylon.ts +++ b/weave-js/src/common/util/render_babylon.ts @@ -394,6 +394,15 @@ const pointCloudScene = ( // Apply vertexData to custom mesh vertexData.applyToMesh(pcMesh); + // A file without any points defined still includes a single, empty "point". + // In order to play nice with Babylon, we position this empty point at 0,0,0. + // Hence, a pointCloud with a single point at 0,0,0 is likely empty. + const isEmpty = + pointCloud.points.length === 1 && + pointCloud.points[0].position[0] === 0 && + pointCloud.points[0].position[1] === 0 && + pointCloud.points[0].position[2] === 0; + camera.parent = pcMesh; const pcMaterial = new Babylon.StandardMaterial('mat', scene); @@ -472,8 +481,8 @@ const pointCloudScene = ( new Vector3(edges.length * 2, edges.length * 2, edges.length * 2) ); - // If we are iterating over camera, target a box - if (index === meta?.cameraIndex) { + // If we are iterating over camera or the cloud is empty, target a box + if (index === meta?.cameraIndex || (index === 0 && isEmpty)) { camera.position = center.add(new Vector3(0, 0, 1000)); camera.target = center; camera.zoomOn([lines]); From 0966c9fb21d00fa6b3ed9019de2c4c471fc6ec30 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Tue, 10 Dec 2024 12:08:50 -0800 Subject: [PATCH 25/52] fix(ui): feedback header when empty (#3191) --- .../Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx index ef6bcbd69ff..0b3c9603fef 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx @@ -96,7 +96,7 @@ export const FeedbackSidebar = ({
Feedback
-
+
{humanAnnotationSpecs.length > 0 ? ( <>
From 35a78d6e83006ea4462d48485651213a4afbd381 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Tue, 10 Dec 2024 16:20:17 -0800 Subject: [PATCH 26/52] chore(ui): make eval section header draggable (#2637) --- .../CompareEvaluationsPage.tsx | 14 +-- .../compareEvaluationsContext.tsx | 25 ++-- .../pages/CompareEvaluationsPage/ecpState.ts | 63 +++++++--- .../pages/CompareEvaluationsPage/ecpTypes.ts | 1 + .../ComparisonDefinitionSection.tsx | 110 ++++++++++++------ .../EvaluationDefinition.tsx | 83 +------------ .../DraggableSection/DraggableItem.tsx | 103 ++++++++++++++++ .../DraggableSection/DraggableSection.tsx | 34 ++++++ .../ExampleCompareSection.tsx | 2 +- .../ExampleFilterSection.tsx | 4 +- .../ScorecardSection/ScorecardSection.tsx | 3 +- 11 files changed, 280 insertions(+), 162 deletions(-) create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx index b302b0262ae..478c4887546 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx @@ -77,8 +77,6 @@ export const CompareEvaluationsPage: React.FC< export const CompareEvaluationsPageContent: React.FC< CompareEvaluationsPageProps > = props => { - const [baselineEvaluationCallId, setBaselineEvaluationCallId] = - React.useState(null); const [comparisonDimensions, setComparisonDimensions] = React.useState(null); @@ -104,14 +102,6 @@ export const CompareEvaluationsPageContent: React.FC< [comparisonDimensions] ); - React.useEffect(() => { - // Only update the baseline if we are switching evaluations, if there - // is more than 1, we are in the compare view and baseline is auto set - if (props.evaluationCallIds.length === 1) { - setBaselineEvaluationCallId(props.evaluationCallIds[0]); - } - }, [props.evaluationCallIds]); - if (props.evaluationCallIds.length === 0) { return
No evaluations to compare
; } @@ -120,13 +110,11 @@ export const CompareEvaluationsPageContent: React.FC< diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx index b65658a890a..638565e8ad6 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx @@ -10,9 +10,6 @@ import {ComparisonDimensionsType} from './ecpState'; const CompareEvaluationsContext = React.createContext<{ state: EvaluationComparisonState; - setBaselineEvaluationCallId: React.Dispatch< - React.SetStateAction - >; setComparisonDimensions: React.Dispatch< React.SetStateAction >; @@ -20,6 +17,7 @@ const CompareEvaluationsContext = React.createContext<{ setSelectedMetrics: (newModel: Record) => void; addEvaluationCall: (newCallId: string) => void; removeEvaluationCall: (callId: string) => void; + setEvaluationCallOrder: (newCallIdOrder: string[]) => void; } | null>(null); export const useCompareEvaluationsState = () => { @@ -33,34 +31,26 @@ export const useCompareEvaluationsState = () => { export const CompareEvaluationsProvider: React.FC<{ entity: string; project: string; + initialEvaluationCallIds: string[]; selectedMetrics: Record | null; setSelectedMetrics: (newModel: Record) => void; - initialEvaluationCallIds: string[]; onEvaluationCallIdsUpdate: (newEvaluationCallIds: string[]) => void; - setBaselineEvaluationCallId: React.Dispatch< - React.SetStateAction - >; setComparisonDimensions: React.Dispatch< React.SetStateAction >; setSelectedInputDigest: React.Dispatch>; - baselineEvaluationCallId?: string; comparisonDimensions?: ComparisonDimensionsType; selectedInputDigest?: string; }> = ({ entity, project, + initialEvaluationCallIds, selectedMetrics, setSelectedMetrics, - - initialEvaluationCallIds, onEvaluationCallIdsUpdate, - setBaselineEvaluationCallId, setComparisonDimensions, - setSelectedInputDigest, - baselineEvaluationCallId, comparisonDimensions, selectedInputDigest, children, @@ -77,7 +67,6 @@ export const CompareEvaluationsProvider: React.FC<{ entity, project, evaluationCallIds, - baselineEvaluationCallId, comparisonDimensions, selectedInputDigest, selectedMetrics ?? undefined @@ -89,7 +78,6 @@ export const CompareEvaluationsProvider: React.FC<{ } return { state: initialState.result, - setBaselineEvaluationCallId, setComparisonDimensions, setSelectedInputDigest, setSelectedMetrics, @@ -105,14 +93,17 @@ export const CompareEvaluationsProvider: React.FC<{ setEvaluationCallIds(newEvaluationCallIds); onEvaluationCallIdsUpdate(newEvaluationCallIds); }, + setEvaluationCallOrder: (newCallIdOrder: string[]) => { + setEvaluationCallIds(newCallIdOrder); + onEvaluationCallIdsUpdate(newCallIdOrder); + }, }; }, [ initialState.loading, initialState.result, + setEvaluationCallIds, evaluationCallIds, onEvaluationCallIdsUpdate, - setEvaluationCallIds, - setBaselineEvaluationCallId, setComparisonDimensions, setSelectedInputDigest, setSelectedMetrics, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts index 35f95dbf14f..e5c1b03d60a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts @@ -18,14 +18,14 @@ import {getMetricIds} from './ecpUtil'; export type EvaluationComparisonState = { // The normalized data for the evaluations data: EvaluationComparisonData; - // The evaluation call id of the baseline model - baselineEvaluationCallId: string; // The dimensions to compare & filter results comparisonDimensions?: ComparisonDimensionsType; // The current digest which is in view selectedInputDigest?: string; // The selected metrics to display selectedMetrics?: Record; + // Ordered call Ids + evaluationCallIdsOrdered: string[]; }; export type ComparisonDimensionsType = Array<{ @@ -43,12 +43,14 @@ export const useEvaluationComparisonState = ( entity: string, project: string, evaluationCallIds: string[], - baselineEvaluationCallId?: string, comparisonDimensions?: ComparisonDimensionsType, selectedInputDigest?: string, selectedMetrics?: Record ): Loadable => { - const data = useEvaluationComparisonData(entity, project, evaluationCallIds); + const orderedCallIds = useMemo(() => { + return getCallIdsOrderedForQuery(evaluationCallIds); + }, [evaluationCallIds]); + const data = useEvaluationComparisonData(entity, project, orderedCallIds); const value = useMemo(() => { if (data.result == null || data.loading) { @@ -92,42 +94,45 @@ export const useEvaluationComparisonState = ( loading: false, result: { data: data.result, - baselineEvaluationCallId: - baselineEvaluationCallId ?? evaluationCallIds[0], comparisonDimensions: newComparisonDimensions, selectedInputDigest, selectedMetrics, + evaluationCallIdsOrdered: evaluationCallIds, }, }; }, [ data.result, data.loading, - baselineEvaluationCallId, - evaluationCallIds, comparisonDimensions, selectedInputDigest, selectedMetrics, + evaluationCallIds, ]); return value; }; +export const getOrderedCallIds = (state: EvaluationComparisonState) => { + return Array.from(state.evaluationCallIdsOrdered); +}; + +export const getBaselineCallId = (state: EvaluationComparisonState) => { + return getOrderedCallIds(state)[0]; +}; + /** - * Should use this over keys of `state.data.evaluationCalls` because it ensures the baseline - * evaluation call is first. + * Sort call IDs to ensure consistent order for memoized query params */ -export const getOrderedCallIds = (state: EvaluationComparisonState) => { - const initial = Object.keys(state.data.evaluationCalls); - moveItemToFront(initial, state.baselineEvaluationCallId); - return initial; +const getCallIdsOrderedForQuery = (callIds: string[]) => { + return Array.from(callIds).sort(); }; /** * Should use this over keys of `state.data.models` because it ensures the baseline model is first. */ export const getOrderedModelRefs = (state: EvaluationComparisonState) => { - const baselineRef = - state.data.evaluationCalls[state.baselineEvaluationCallId].modelRef; + const baselineCallId = getBaselineCallId(state); + const baselineRef = state.data.evaluationCalls[baselineCallId].modelRef; const refs = Object.keys(state.data.models); // Make sure the baseline model is first moveItemToFront(refs, baselineRef); @@ -145,3 +150,29 @@ const moveItemToFront = (arr: T[], item: T): T[] => { } return arr; }; + +export const getOrderedEvalsWithNewBaseline = ( + evaluationCallIds: string[], + newBaselineCallId: string +) => { + return moveItemToFront(evaluationCallIds, newBaselineCallId); +}; + +export const swapEvaluationCalls = ( + evaluationCallIds: string[], + ndx1: number, + ndx2: number +): string[] => { + return swapArrayItems(evaluationCallIds, ndx1, ndx2); +}; + +const swapArrayItems = (arr: T[], ndx1: number, ndx2: number): T[] => { + if (ndx1 < 0 || ndx2 < 0 || ndx1 >= arr.length || ndx2 >= arr.length) { + throw new Error('Index out of bounds'); + } + const newArr = [...arr]; + const from = newArr[ndx1]; + newArr[ndx1] = newArr[ndx2]; + newArr[ndx2] = from; + return newArr; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts index 7454e0707b4..b4642fae240 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts @@ -18,6 +18,7 @@ export type EvaluationComparisonData = { }; // EvaluationCalls are the specific calls of an evaluation. + // The visual order of the evaluation calls is determined by the order of the keys. evaluationCalls: { [callId: string]: EvaluationCall; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx index 66ce56c4cb0..2704a66cbea 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx @@ -12,44 +12,81 @@ import { import {useCallsForQuery} from '../../../CallsPage/callsTableQuery'; import {useEvaluationsFilter} from '../../../CallsPage/evaluationsFilter'; import {Id} from '../../../common/Id'; +import {opNiceName} from '../../../common/Links'; import {useWFHooks} from '../../../wfReactInterface/context'; import { CallSchema, ObjectVersionKey, } from '../../../wfReactInterface/wfDataModelHooksInterface'; import {useCompareEvaluationsState} from '../../compareEvaluationsContext'; -import {STANDARD_PADDING} from '../../ecpConstants'; -import {getOrderedCallIds} from '../../ecpState'; -import {EvaluationComparisonState} from '../../ecpState'; +import { + EvaluationComparisonState, + getOrderedCallIds, + getOrderedEvalsWithNewBaseline, + swapEvaluationCalls, +} from '../../ecpState'; import {HorizontalBox} from '../../Layout'; -import {EvaluationDefinition, VerticalBar} from './EvaluationDefinition'; +import {ItemDef} from '../DraggableSection/DraggableItem'; +import {DraggableSection} from '../DraggableSection/DraggableSection'; +import {VerticalBar} from './EvaluationDefinition'; export const ComparisonDefinitionSection: React.FC<{ state: EvaluationComparisonState; }> = props => { - const evalCallIds = useMemo( - () => getOrderedCallIds(props.state), - [props.state] - ); + const {setEvaluationCallOrder, removeEvaluationCall} = + useCompareEvaluationsState(); + + const callIds = useMemo(() => { + return getOrderedCallIds(props.state); + }, [props.state]); + + const items: ItemDef[] = useMemo(() => { + return callIds.map(callId => ({ + key: 'evaluations', + value: callId, + label: props.state.data.evaluationCalls[callId]?.name ?? callId, + })); + }, [callIds, props.state.data.evaluationCalls]); + + const onSetBaseline = (value: string | null) => { + if (!value) { + return; + } + const newSortOrder = getOrderedEvalsWithNewBaseline(callIds, value); + setEvaluationCallOrder(newSortOrder); + }; + const onRemoveItem = (value: string) => removeEvaluationCall(value); + const onSortEnd = ({ + oldIndex, + newIndex, + }: { + oldIndex: number; + newIndex: number; + }) => { + const newSortOrder = swapEvaluationCalls(callIds, oldIndex, newIndex); + setEvaluationCallOrder(newSortOrder); + }; return ( - - {evalCallIds.map((key, ndx) => { - return ( - - - - ); - })} - - + +
+ + + + + + +
+
); }; @@ -81,7 +118,7 @@ const ModelRefLabel: React.FC<{modelRef: string}> = props => { const objectVersion = useObjectVersion(objVersionKey); return ( - {objectVersion.result?.objectId}:{objectVersion.result?.versionIndex} + {objectVersion.result?.objectId}:v{objectVersion.result?.versionIndex} ); }; @@ -119,10 +156,9 @@ const AddEvaluationButton: React.FC<{ const evalsNotComparing = useMemo(() => { return calls.result.filter( - call => - !Object.keys(props.state.data.evaluationCalls).includes(call.callId) + call => !getOrderedCallIds(props.state).includes(call.callId) ); - }, [calls.result, props.state.data.evaluationCalls]); + }, [calls.result, props.state]); const [menuOptions, setMenuOptions] = useState(evalsNotComparing); @@ -222,12 +258,18 @@ const AddEvaluationButton: React.FC<{ variant="ghost" size="small" className="pb-8 pt-8 font-['Source_Sans_Pro'] text-base font-normal text-moon-800" - onClick={() => { - addEvaluationCall(call.callId); - }}> + onClick={() => addEvaluationCall(call.callId)}> <> - {call.displayName ?? call.spanName} - + + {call.displayName ?? opNiceName(call.spanName)} + + + + diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx index adc80d044f6..5dcf835e378 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx @@ -1,8 +1,5 @@ import {Box} from '@material-ui/core'; import {Circle} from '@mui/icons-material'; -import {PopupDropdown} from '@wandb/weave/common/components/PopupDropdown'; -import {Button} from '@wandb/weave/components/Button'; -import {Pill} from '@wandb/weave/components/Tag'; import React, {useMemo} from 'react'; import { @@ -17,87 +14,17 @@ import {SmallRef} from '../../../../../Browse2/SmallRef'; import {CallLink, ObjectVersionLink} from '../../../common/Links'; import {useWFHooks} from '../../../wfReactInterface/context'; import {ObjectVersionKey} from '../../../wfReactInterface/wfDataModelHooksInterface'; -import {useCompareEvaluationsState} from '../../compareEvaluationsContext'; -import { - BOX_RADIUS, - CIRCLE_SIZE, - EVAL_DEF_HEIGHT, - STANDARD_BORDER, -} from '../../ecpConstants'; +import {CIRCLE_SIZE} from '../../ecpConstants'; import {EvaluationComparisonState} from '../../ecpState'; -import {HorizontalBox} from '../../Layout'; - -export const EvaluationDefinition: React.FC<{ - state: EvaluationComparisonState; - callId: string; -}> = props => { - const {removeEvaluationCall, setBaselineEvaluationCallId} = - useCompareEvaluationsState(); - - const menuOptions = useMemo(() => { - return [ - { - key: 'add-to-baseline', - content: 'Set as baseline', - onClick: () => { - setBaselineEvaluationCallId(props.callId); - }, - disabled: props.callId === props.state.baselineEvaluationCallId, - }, - { - key: 'remove', - content: 'Remove', - onClick: () => { - removeEvaluationCall(props.callId); - }, - disabled: Object.keys(props.state.data.evaluationCalls).length === 1, - }, - ]; - }, [ - props.callId, - props.state.baselineEvaluationCallId, - props.state.data.evaluationCalls, - removeEvaluationCall, - setBaselineEvaluationCallId, - ]); - - return ( - - - {props.callId === props.state.baselineEvaluationCallId && ( - - )} -
- - } - /> -
-
- ); -}; export const EvaluationCallLink: React.FC<{ callId: string; state: EvaluationComparisonState; }> = props => { - const evaluationCall = props.state.data.evaluationCalls[props.callId]; + const evaluationCall = props.state.data.evaluationCalls?.[props.callId]; + if (!evaluationCall) { + return null; + } const {entity, project} = props.state.data; return ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx new file mode 100644 index 00000000000..1510c502b99 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx @@ -0,0 +1,103 @@ +import {Button} from '@wandb/weave/components/Button'; +import * as DropdownMenu from '@wandb/weave/components/DropdownMenu'; +import {Icon} from '@wandb/weave/components/Icon'; +import {Pill} from '@wandb/weave/components/Tag/Pill'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import classNames from 'classnames'; +import React, {useState} from 'react'; +import {SortableElement, SortableHandle} from 'react-sortable-hoc'; + +import {EvaluationComparisonState} from '../../ecpState'; +import {EvaluationCallLink} from '../ComparisonDefinitionSection/EvaluationDefinition'; + +export type ItemDef = { + key: string; + value: string; + label?: string; +}; + +type DraggableItemProps = { + state: EvaluationComparisonState; + item: ItemDef; + numItems: number; + idx: number; + onRemoveItem: (value: string) => void; + onSetBaseline: (value: string | null) => void; +}; + +export const DraggableItem = SortableElement( + ({ + state, + item, + numItems, + idx, + onRemoveItem, + onSetBaseline, + }: DraggableItemProps) => { + const isDeletable = numItems > 1; + const isBaseline = idx === 0; + const [isOpen, setIsOpen] = useState(false); + + const onMakeBaselinePropagated = (e: React.MouseEvent) => { + e.stopPropagation(); + onSetBaseline(item.value); + }; + + const onRemoveItemPropagated = (e: React.MouseEvent) => { + e.stopPropagation(); + onRemoveItem(item.value); + }; + + return ( + +
+ +
+ + {isBaseline && ( + + )} +
+ + +
+
+ ); + } +); + +const DragHandle = SortableHandle(() => ( +
+ +
+)); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx new file mode 100644 index 00000000000..23a03ceb5b3 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx @@ -0,0 +1,34 @@ +import React from 'react'; +import {SortableContainer} from 'react-sortable-hoc'; + +import {EvaluationComparisonState} from '../../ecpState'; +import {DraggableItem} from './DraggableItem'; +import {ItemDef} from './DraggableItem'; + +type DraggableSectionProps = { + state: EvaluationComparisonState; + items: ItemDef[]; + onSetBaseline: (value: string | null) => void; + onRemoveItem: (value: string) => void; +}; + +export const DraggableSection = SortableContainer( + ({state, items, onSetBaseline, onRemoveItem}: DraggableSectionProps) => { + return ( +
+ {items.map((item, index) => ( + + ))} +
+ ); + } +); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx index 6041492b5c5..398f65ecd45 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx @@ -149,7 +149,7 @@ const stickySidebarHeaderMixin: React.CSSProperties = { /** * This component will occupy the entire space provided by the parent container. - * It is intended to be used in teh CompareEvaluations page, as it depends on + * It is intended to be used in the CompareEvaluations page, as it depends on * the EvaluationComparisonState. However, in principle, it is a general purpose * model-output comparison tool. It allows the user to view inputs, then compare * model outputs and evaluation metrics across multiple trials. diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx index b2c8773a7d1..1146c8ea960 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx @@ -13,7 +13,7 @@ import { } from '../../compositeMetricsUtil'; import {PLOT_HEIGHT, STANDARD_PADDING} from '../../ecpConstants'; import {MAX_PLOT_DOT_SIZE, MIN_PLOT_DOT_SIZE} from '../../ecpConstants'; -import {EvaluationComparisonState} from '../../ecpState'; +import {EvaluationComparisonState, getBaselineCallId} from '../../ecpState'; import {metricDefinitionId} from '../../ecpUtil'; import { flattenedDimensionPath, @@ -103,7 +103,7 @@ const SingleDimensionFilter: React.FC<{ }, [props.state.data]); const {setComparisonDimensions} = useCompareEvaluationsState(); - const baselineCallId = props.state.baselineEvaluationCallId; + const baselineCallId = getBaselineCallId(props.state); const compareCallId = Object.keys(props.state.data.evaluationCalls).find( callId => callId !== baselineCallId )!; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx index 4b319150fa8..303f640f47e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx @@ -32,6 +32,7 @@ import { } from '../../ecpConstants'; import { EvaluationComparisonState, + getBaselineCallId, getOrderedCallIds, getOrderedModelRefs, } from '../../ecpState'; @@ -414,7 +415,7 @@ export const ScorecardSection: React.FC<{ {evalCallIds.map((evalCallId, mNdx) => { const baseline = resolveSummaryMetricResult( - props.state.baselineEvaluationCallId, + getBaselineCallId(props.state), groupName, metricKey, compositeSummaryMetrics, From f967984c8bf49bf49e1e5a500b338381f2481e4e Mon Sep 17 00:00:00 2001 From: Erica Diaz <156136421+ericakdiaz@users.noreply.github.com> Date: Tue, 10 Dec 2024 16:23:39 -0800 Subject: [PATCH 27/52] chore(weave): Update ToggleButtonGroup to allow disabling individual options (#3194) --- weave-js/src/components/ToggleButtonGroup.tsx | 67 +++++++++++-------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/weave-js/src/components/ToggleButtonGroup.tsx b/weave-js/src/components/ToggleButtonGroup.tsx index 65b2c538975..eca93e95657 100644 --- a/weave-js/src/components/ToggleButtonGroup.tsx +++ b/weave-js/src/components/ToggleButtonGroup.tsx @@ -9,6 +9,7 @@ import {Tailwind} from './Tailwind'; export type ToggleOption = { value: string; icon?: IconName; + isDisabled?: boolean; }; export type ToggleButtonGroupProps = { @@ -37,7 +38,10 @@ export const ToggleButtonGroup = React.forwardRef< } const handleValueChange = (newValue: string) => { - if (newValue !== value) { + if ( + newValue !== value && + options.find(option => option.value === newValue)?.isDisabled !== true + ) { onValueChange(newValue); } }; @@ -49,34 +53,39 @@ export const ToggleButtonGroup = React.forwardRef< onValueChange={handleValueChange} className="flex gap-px" ref={ref}> - {options.map(({value: optionValue, icon}) => ( - - - - ))} + {options.map( + ({value: optionValue, icon, isDisabled: optionIsDisabled}) => ( + + + + ) + )} ); From 4e1e2c9a6da47ceff2fbd113053d959b43b4ce6f Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 10 Dec 2024 16:30:13 -0800 Subject: [PATCH 28/52] chore(weave): Fix predict detection heuristic for evals (#3196) * init * lint --- .../tsDataModelHooksEvaluationComparison.ts | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts index 8418f2976ad..847421464e3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts @@ -69,7 +69,6 @@ * across different datasets. */ -import _ from 'lodash'; import {sum} from 'lodash'; import {useEffect, useMemo, useRef, useState} from 'react'; @@ -479,17 +478,8 @@ const fetchEvaluationComparisonData = async ( const maybeDigest = parts[1]; if (maybeDigest != null && !maybeDigest.includes('/')) { const rowDigest = maybeDigest; - const possiblePredictNames = [ - 'predict', - 'infer', - 'forward', - 'invoke', - ]; const isProbablyPredictCall = - (_.some(possiblePredictNames, name => - traceCall.op_name.includes(`.${name}:`) - ) && - modelRefs.includes(traceCall.inputs.self)) || + modelRefs.includes(traceCall.inputs.self) || modelRefs.includes(traceCall.op_name); const isProbablyScoreCall = scorerRefs.has(traceCall.op_name); From bf39669bf3ebba977e9b6cd9224c9abce22a3a62 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Tue, 10 Dec 2024 16:36:47 -0800 Subject: [PATCH 29/52] fix(ui): Evaluations -- normalize radar data, show real values for bar charts (#3199) --- .../SummaryPlotsSection.tsx | 94 ++++++++++--------- 1 file changed, 50 insertions(+), 44 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx index c23ffcce04d..02c456df850 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx @@ -29,7 +29,7 @@ export const SummaryPlots: React.FC<{ state: EvaluationComparisonState; setSelectedMetrics: (newModel: Record) => void; }> = ({state, setSelectedMetrics}) => { - const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics(state); + const {radarData, allMetricNames} = usePlotDataFromMetrics(state); const {selectedMetrics} = state; // Initialize selectedMetrics if null @@ -237,11 +237,10 @@ const useFilteredData = ( return data; }, [radarData, selectedMetrics]); -function getMetricValuesFromRadarData(radarData: RadarPlotData): { +function getMetricValuesMap(radarData: RadarPlotData): { [metric: string]: number[]; } { const metricValues: {[metric: string]: number[]} = {}; - // Gather all values for each metric Object.values(radarData).forEach(callData => { Object.entries(callData.metrics).forEach(([metric, value]) => { if (!metricValues[metric]) { @@ -253,37 +252,54 @@ function getMetricValuesFromRadarData(radarData: RadarPlotData): { return metricValues; } -function getMetricMinsFromRadarData(radarData: RadarPlotData): { - [metric: string]: number; +function normalizeMetricValues(values: number[]): { + normalizedValues: number[]; + normalizer: number; } { - const metricValues = getMetricValuesFromRadarData(radarData); - const metricMins: {[metric: string]: number} = {}; - Object.entries(metricValues).forEach(([metric, values]) => { - metricMins[metric] = Math.min(...values); - }); - return metricMins; + const min = Math.min(...values); + const max = Math.max(...values); + + if (min === max) { + return { + normalizedValues: values.map(() => 0.5), + normalizer: 1, + }; + } + + // Handle negative values by shifting + const shiftedValues = min < 0 ? values.map(v => v - min) : values; + const maxValue = min < 0 ? max - min : max; + + const maxPower = Math.ceil(Math.log2(maxValue)); + const normalizer = Math.pow(2, maxPower); + + return { + normalizedValues: shiftedValues.map(v => v / normalizer), + normalizer, + }; } -function normalizeDataForRadarPlot(radarData: RadarPlotData): RadarPlotData { - const metricMins = getMetricMinsFromRadarData(radarData); +function normalizeDataForRadarPlot( + radarDataOriginal: RadarPlotData +): RadarPlotData { + const radarData = Object.fromEntries( + Object.entries(radarDataOriginal).map(([callId, callData]) => [ + callId, + {...callData, metrics: {...callData.metrics}}, + ]) + ); - const normalizedData: RadarPlotData = {}; - Object.entries(radarData).forEach(([callId, callData]) => { - normalizedData[callId] = { - name: callData.name, - color: callData.color, - metrics: {}, - }; + const metricValues = getMetricValuesMap(radarData); - Object.entries(callData.metrics).forEach(([metric, value]) => { - const min = metricMins[metric]; - // Only shift values if there are negative values - const normalizedValue = min < 0 ? value - min : value; - normalizedData[callId].metrics[metric] = normalizedValue; + // Normalize each metric independently + Object.entries(metricValues).forEach(([metric, values]) => { + const {normalizedValues} = normalizeMetricValues(values); + Object.values(radarData).forEach((callData, index) => { + callData.metrics[metric] = normalizedValues[index]; }); }); - return normalizedData; + return radarData; } const useBarPlotData = (filteredData: RadarPlotData) => @@ -317,7 +333,9 @@ const useBarPlotData = (filteredData: RadarPlotData) => type: 'bar', y: metricBin.values, x: metricBin.callIds, - text: metricBin.values.map(value => value.toFixed(3)), + text: metricBin.values.map(value => + Number.isInteger(value) ? value.toString() : value.toFixed(3) + ), textposition: 'outside', textfont: {size: 14, color: 'black'}, name: metric, @@ -408,16 +426,7 @@ const usePaginatedPlots = ( return {plotsToShow, totalPlots, startIndex, endIndex, totalPages}; }; -function normalizeValues(values: Array): number[] { - // find the max value - // find the power of 2 that is greater than the max value - // divide all values by that power of 2 - const maxVal = Math.max(...(values.filter(v => v !== undefined) as number[])); - const maxPower = Math.ceil(Math.log2(maxVal)); - return values.map(val => (val ? val / 2 ** maxPower : 0)); -} - -const useNormalizedPlotDataFromMetrics = ( +const usePlotDataFromMetrics = ( state: EvaluationComparisonState ): {radarData: RadarPlotData; allMetricNames: Set} => { const compositeMetrics = useMemo(() => { @@ -428,7 +437,7 @@ const useNormalizedPlotDataFromMetrics = ( }, [state]); return useMemo(() => { - const normalizedMetrics = Object.values(compositeMetrics) + const metrics = Object.values(compositeMetrics) .map(scoreGroup => Object.values(scoreGroup.metrics)) .flat() .map(metric => { @@ -449,11 +458,8 @@ const useNormalizedPlotDataFromMetrics = ( return val; } }); - const normalizedValues = normalizeValues(values); const evalScores: {[evalCallId: string]: number | undefined} = - Object.fromEntries( - callIds.map((key, i) => [key, normalizedValues[i]]) - ); + Object.fromEntries(callIds.map((key, i) => [key, values[i]])); const metricLabel = flattenedDimensionPath( Object.values(metric.scorerRefs)[0].metric @@ -472,7 +478,7 @@ const useNormalizedPlotDataFromMetrics = ( name: evalCall.name, color: evalCall.color, metrics: Object.fromEntries( - normalizedMetrics.map(metric => { + metrics.map(metric => { return [ metric.metricLabel, metric.evalScores[evalCall.callId] ?? 0, @@ -483,7 +489,7 @@ const useNormalizedPlotDataFromMetrics = ( ]; }) ); - const allMetricNames = new Set(normalizedMetrics.map(m => m.metricLabel)); + const allMetricNames = new Set(metrics.map(m => m.metricLabel)); return {radarData, allMetricNames}; }, [callIds, compositeMetrics, state.data.evaluationCalls]); }; From af81503aa330c0659a8c65b316b1071782a543e6 Mon Sep 17 00:00:00 2001 From: Jeff Raubitschek Date: Wed, 11 Dec 2024 12:34:22 -0800 Subject: [PATCH 30/52] chore(weave): update db to use replicated tables (#3148) --- .../test_clickhouse_trace_server_migrator.py | 230 ++++++++++++++++++ .../clickhouse_trace_server_migrator.py | 132 ++++++++-- 2 files changed, 336 insertions(+), 26 deletions(-) create mode 100644 tests/trace_server/test_clickhouse_trace_server_migrator.py diff --git a/tests/trace_server/test_clickhouse_trace_server_migrator.py b/tests/trace_server/test_clickhouse_trace_server_migrator.py new file mode 100644 index 00000000000..3a6a92f2479 --- /dev/null +++ b/tests/trace_server/test_clickhouse_trace_server_migrator.py @@ -0,0 +1,230 @@ +import types +from unittest.mock import Mock, call, patch + +import pytest + +from weave.trace_server import clickhouse_trace_server_migrator as trace_server_migrator +from weave.trace_server.clickhouse_trace_server_migrator import MigrationError + + +@pytest.fixture +def mock_costs(): + with patch( + "weave.trace_server.costs.insert_costs.should_insert_costs", return_value=False + ) as mock_should_insert: + with patch( + "weave.trace_server.costs.insert_costs.get_current_costs", return_value=[] + ) as mock_get_costs: + yield + + +@pytest.fixture +def migrator(): + ch_client = Mock() + migrator = trace_server_migrator.ClickHouseTraceServerMigrator(ch_client) + migrator._get_migration_status = Mock() + migrator._get_migrations = Mock() + migrator._determine_migrations_to_apply = Mock() + migrator._update_migration_status = Mock() + ch_client.command.reset_mock() + return migrator + + +def test_apply_migrations_with_target_version(mock_costs, migrator, tmp_path): + # Setup + migrator._get_migration_status.return_value = { + "curr_version": 1, + "partially_applied_version": None, + } + migrator._get_migrations.return_value = { + "1": {"up": "1.up.sql", "down": "1.down.sql"}, + "2": {"up": "2.up.sql", "down": "2.down.sql"}, + } + migrator._determine_migrations_to_apply.return_value = [(2, "2.up.sql")] + + # Create a temporary migration file + migration_dir = tmp_path / "migrations" + migration_dir.mkdir() + migration_file = migration_dir / "2.up.sql" + migration_file.write_text( + "CREATE TABLE test1 (id Int32);\nCREATE TABLE test2 (id Int32);" + ) + + # Mock the migration directory path + with patch("os.path.dirname") as mock_dirname: + mock_dirname.return_value = str(tmp_path) + + # Execute + migrator.apply_migrations("test_db", target_version=2) + + # Verify + migrator._get_migration_status.assert_called_once_with("test_db") + migrator._get_migrations.assert_called_once() + migrator._determine_migrations_to_apply.assert_called_once_with( + 1, migrator._get_migrations.return_value, 2 + ) + + # Verify migration execution + assert migrator._update_migration_status.call_count == 2 + migrator._update_migration_status.assert_has_calls( + [call("test_db", 2, is_start=True), call("test_db", 2, is_start=False)] + ) + + # Verify the actual SQL commands were executed + ch_client = migrator.ch_client + assert ch_client.command.call_count == 2 + ch_client.command.assert_has_calls( + [call("CREATE TABLE test1 (id Int32)"), call("CREATE TABLE test2 (id Int32)")] + ) + + +def test_execute_migration_command(mock_costs, migrator): + # Setup + ch_client = migrator.ch_client + ch_client.database = "original_db" + + # Execute + migrator._execute_migration_command("test_db", "CREATE TABLE test (id Int32)") + + # Verify + assert ch_client.database == "original_db" # Should restore original database + ch_client.command.assert_called_once_with("CREATE TABLE test (id Int32)") + + +def test_migration_replicated(mock_costs, migrator): + ch_client = migrator.ch_client + orig = "CREATE TABLE test (id String, project_id String) ENGINE = MergeTree ORDER BY (project_id, id);" + migrator._execute_migration_command("test_db", orig) + ch_client.command.assert_called_once_with(orig) + + +def test_update_migration_status(mock_costs, migrator): + # Don't mock _update_migration_status for this test + migrator._update_migration_status = types.MethodType( + trace_server_migrator.ClickHouseTraceServerMigrator._update_migration_status, + migrator, + ) + + # Test start of migration + migrator._update_migration_status("test_db", 2, is_start=True) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE partially_applied_version = 2 WHERE db_name = 'test_db'" + ) + + # Test end of migration + migrator._update_migration_status("test_db", 2, is_start=False) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE curr_version = 2, partially_applied_version = NULL WHERE db_name = 'test_db'" + ) + + +def test_is_safe_identifier(mock_costs, migrator): + # Valid identifiers + assert migrator._is_safe_identifier("test_db") + assert migrator._is_safe_identifier("my_db123") + assert migrator._is_safe_identifier("db.table") + + # Invalid identifiers + assert not migrator._is_safe_identifier("test-db") + assert not migrator._is_safe_identifier("db;") + assert not migrator._is_safe_identifier("db'name") + assert not migrator._is_safe_identifier("db/*") + + +def test_create_db_sql_validation(mock_costs, migrator): + # Test invalid database name + with pytest.raises(MigrationError, match="Invalid database name"): + migrator._create_db_sql("test;db") + + # Test replicated mode with invalid values + migrator.replicated = True + migrator.replicated_cluster = "test;cluster" + with pytest.raises(MigrationError, match="Invalid cluster name"): + migrator._create_db_sql("test_db") + + migrator.replicated_cluster = "test_cluster" + migrator.replicated_path = "/clickhouse/bad;path/{db}" + with pytest.raises(MigrationError, match="Invalid replicated path"): + migrator._create_db_sql("test_db") + + +def test_create_db_sql_non_replicated(mock_costs, migrator): + # Test non-replicated mode + migrator.replicated = False + sql = migrator._create_db_sql("test_db") + assert sql.strip() == "CREATE DATABASE IF NOT EXISTS test_db" + + +def test_create_db_sql_replicated(mock_costs, migrator): + # Test replicated mode + migrator.replicated = True + migrator.replicated_path = "/clickhouse/tables/{db}" + migrator.replicated_cluster = "test_cluster" + + sql = migrator._create_db_sql("test_db") + expected = """ + CREATE DATABASE IF NOT EXISTS test_db ON CLUSTER test_cluster ENGINE=Replicated('/clickhouse/tables/test_db', '{shard}', '{replica}') + """.strip() + assert sql.strip() == expected + + +def test_format_replicated_sql_non_replicated(mock_costs, migrator): + # Test that SQL is unchanged when replicated=False + migrator.replicated = False + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql + + +def test_format_replicated_sql_replicated(mock_costs, migrator): + # Test that MergeTree engines are converted to Replicated variants + migrator.replicated = True + + test_cases = [ + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedSummingMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedReplacingMergeTree", + ), + # Test with extra whitespace + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + # Test with parameters + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree()", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree()", + ), + ] + + for input_sql, expected_sql in test_cases: + assert migrator._format_replicated_sql(input_sql) == expected_sql + + +def test_format_replicated_sql_non_mergetree(mock_costs, migrator): + # Test that non-MergeTree engines are left unchanged + migrator.replicated = True + + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = Memory", + "CREATE TABLE test (id Int32) ENGINE = Log", + "CREATE TABLE test (id Int32) ENGINE = TinyLog", + # This should not be changed as it's not a complete word match + "CREATE TABLE test (id Int32) ENGINE = MyMergeTreeCustom", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql diff --git a/weave/trace_server/clickhouse_trace_server_migrator.py b/weave/trace_server/clickhouse_trace_server_migrator.py index 30dffe89365..4336630bf50 100644 --- a/weave/trace_server/clickhouse_trace_server_migrator.py +++ b/weave/trace_server/clickhouse_trace_server_migrator.py @@ -1,6 +1,7 @@ # Clickhouse Trace Server Manager import logging import os +import re from typing import Optional from clickhouse_connect.driver.client import Client as CHClient @@ -9,6 +10,11 @@ logger = logging.getLogger(__name__) +# These settings are only used when `replicated` mode is enabled for +# self managed clickhouse instances. +DEFAULT_REPLICATED_PATH = "/clickhouse/tables/{db}" +DEFAULT_REPLICATED_CLUSTER = "weave_cluster" + class MigrationError(RuntimeError): """Raised when a migration error occurs.""" @@ -16,15 +22,77 @@ class MigrationError(RuntimeError): class ClickHouseTraceServerMigrator: ch_client: CHClient + replicated: bool + replicated_path: str + replicated_cluster: str def __init__( self, ch_client: CHClient, + replicated: Optional[bool] = None, + replicated_path: Optional[str] = None, + replicated_cluster: Optional[str] = None, ): super().__init__() self.ch_client = ch_client + self.replicated = False if replicated is None else replicated + self.replicated_path = ( + DEFAULT_REPLICATED_PATH if replicated_path is None else replicated_path + ) + self.replicated_cluster = ( + DEFAULT_REPLICATED_CLUSTER + if replicated_cluster is None + else replicated_cluster + ) self._initialize_migration_db() + def _is_safe_identifier(self, value: str) -> bool: + """Check if a string is safe to use as an identifier in SQL.""" + return bool(re.match(r"^[a-zA-Z0-9_\.]+$", value)) + + def _format_replicated_sql(self, sql_query: str) -> str: + """Format SQL query to use replicated engines if replicated mode is enabled.""" + if not self.replicated: + return sql_query + + # Match "ENGINE = MergeTree" followed by word boundary + pattern = r"ENGINE\s*=\s*(\w+)?MergeTree\b" + + def replace_engine(match: re.Match[str]) -> str: + engine_prefix = match.group(1) or "" + return f"ENGINE = Replicated{engine_prefix}MergeTree" + + return re.sub(pattern, replace_engine, sql_query, flags=re.IGNORECASE) + + def _create_db_sql(self, db_name: str) -> str: + """Geneate SQL database create string for normal and replicated databases.""" + if not self._is_safe_identifier(db_name): + raise MigrationError(f"Invalid database name: {db_name}") + + replicated_engine = "" + replicated_cluster = "" + if self.replicated: + if not self._is_safe_identifier(self.replicated_cluster): + raise MigrationError(f"Invalid cluster name: {self.replicated_cluster}") + + replicated_path = self.replicated_path.replace("{db}", db_name) + if not all( + self._is_safe_identifier(part) + for part in replicated_path.split("/") + if part + ): + raise MigrationError(f"Invalid replicated path: {replicated_path}") + + replicated_cluster = f" ON CLUSTER {self.replicated_cluster}" + replicated_engine = ( + f" ENGINE=Replicated('{replicated_path}', '{{shard}}', '{{replica}}')" + ) + + create_db_sql = f""" + CREATE DATABASE IF NOT EXISTS {db_name}{replicated_cluster}{replicated_engine} + """ + return create_db_sql + def apply_migrations( self, target_db: str, target_version: Optional[int] = None ) -> None: @@ -46,20 +114,15 @@ def apply_migrations( return logger.info(f"Migrations to apply: {migrations_to_apply}") if status["curr_version"] == 0: - self.ch_client.command(f"CREATE DATABASE IF NOT EXISTS {target_db}") + self.ch_client.command(self._create_db_sql(target_db)) for target_version, migration_file in migrations_to_apply: self._apply_migration(target_db, target_version, migration_file) if should_insert_costs(status["curr_version"], target_version): insert_costs(self.ch_client, target_db) def _initialize_migration_db(self) -> None: - self.ch_client.command( - """ - CREATE DATABASE IF NOT EXISTS db_management - """ - ) - self.ch_client.command( - """ + self.ch_client.command(self._create_db_sql("db_management")) + create_table_sql = """ CREATE TABLE IF NOT EXISTS db_management.migrations ( db_name String, @@ -69,7 +132,7 @@ def _initialize_migration_db(self) -> None: ENGINE = MergeTree() ORDER BY (db_name) """ - ) + self.ch_client.command(self._format_replicated_sql(create_table_sql)) def _get_migration_status(self, db_name: str) -> dict: column_names = ["db_name", "curr_version", "partially_applied_version"] @@ -184,31 +247,48 @@ def _determine_migrations_to_apply( return [] + def _execute_migration_command(self, target_db: str, command: str) -> None: + """Execute a single migration command in the context of the target database.""" + command = command.strip() + if len(command) == 0: + return + curr_db = self.ch_client.database + self.ch_client.database = target_db + self.ch_client.command(self._format_replicated_sql(command)) + self.ch_client.database = curr_db + + def _update_migration_status( + self, target_db: str, target_version: int, is_start: bool = True + ) -> None: + """Update the migration status in db_management.migrations table.""" + if is_start: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}'" + ) + else: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}'" + ) + def _apply_migration( self, target_db: str, target_version: int, migration_file: str ) -> None: logger.info(f"Applying migration {migration_file} to `{target_db}`") migration_dir = os.path.join(os.path.dirname(__file__), "migrations") migration_file_path = os.path.join(migration_dir, migration_file) + with open(migration_file_path) as f: migration_sql = f.read() - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}' - """ - ) + + # Mark migration as partially applied + self._update_migration_status(target_db, target_version, is_start=True) + + # Execute each command in the migration migration_sub_commands = migration_sql.split(";") for command in migration_sub_commands: - command = command.strip() - if len(command) == 0: - continue - curr_db = self.ch_client.database - self.ch_client.database = target_db - self.ch_client.command(command) - self.ch_client.database = curr_db - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}' - """ - ) + self._execute_migration_command(target_db, command) + + # Mark migration as fully applied + self._update_migration_status(target_db, target_version, is_start=False) + logger.info(f"Migration {migration_file} applied to `{target_db}`") From 4c6db183be9b356dfc6781670c5b8f5fb6fd7532 Mon Sep 17 00:00:00 2001 From: Josiah Lee Date: Wed, 11 Dec 2024 12:54:59 -0800 Subject: [PATCH 31/52] add choices drawer (#3203) --- .../Browse3/pages/ChatView/ChoiceView.tsx | 8 +- .../Browse3/pages/ChatView/ChoicesDrawer.tsx | 104 ++++++++++++++++++ .../Browse3/pages/ChatView/ChoicesView.tsx | 39 ++++--- .../pages/ChatView/ChoicesViewCarousel.tsx | 27 +++-- .../pages/ChatView/ChoicesViewLinear.tsx | 73 ------------ .../Home/Browse3/pages/ChatView/types.ts | 2 - 6 files changed, 148 insertions(+), 105 deletions(-) create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx delete mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewLinear.tsx diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx index c8143f9549e..d1a2c59d5d0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx @@ -6,15 +6,21 @@ import {Choice} from './types'; type ChoiceViewProps = { choice: Choice; isStructuredOutput?: boolean; + isNested?: boolean; }; -export const ChoiceView = ({choice, isStructuredOutput}: ChoiceViewProps) => { +export const ChoiceView = ({ + choice, + isStructuredOutput, + isNested, +}: ChoiceViewProps) => { const {message} = choice; return ( ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx new file mode 100644 index 00000000000..16d6897c27b --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx @@ -0,0 +1,104 @@ +import {Box, Drawer} from '@mui/material'; +import {MOON_200} from '@wandb/weave/common/css/color.styles'; +import {Tag} from '@wandb/weave/components/Tag'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import React from 'react'; + +import {Button} from '../../../../../Button'; +import {ChoiceView} from './ChoiceView'; +import {Choice} from './types'; + +type ChoicesDrawerProps = { + choices: Choice[]; + isStructuredOutput?: boolean; + isDrawerOpen: boolean; + setIsDrawerOpen: React.Dispatch>; + selectedChoiceIndex: number; + setSelectedChoiceIndex: (choiceIndex: number) => void; +}; + +export const ChoicesDrawer = ({ + choices, + isStructuredOutput, + isDrawerOpen, + setIsDrawerOpen, + selectedChoiceIndex, + setSelectedChoiceIndex, +}: ChoicesDrawerProps) => { + return ( + setIsDrawerOpen(false)} + title="Choices" + anchor="right" + sx={{ + '& .MuiDrawer-paper': {mt: '60px', width: '400px'}, + }}> + + + Responses + + + ) : ( + + )} +
+ +
+ ))} +
+ + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx index c22df7c63d7..5ddc7f12202 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx @@ -1,9 +1,9 @@ import React, {useState} from 'react'; +import {ChoicesDrawer} from './ChoicesDrawer'; import {ChoicesViewCarousel} from './ChoicesViewCarousel'; -import {ChoicesViewLinear} from './ChoicesViewLinear'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewProps = { choices: Choice[]; @@ -14,7 +14,12 @@ export const ChoicesView = ({ choices, isStructuredOutput, }: ChoicesViewProps) => { - const [mode, setMode] = useState('linear'); + const [isDrawerOpen, setIsDrawerOpen] = useState(false); + const [localSelectedChoiceIndex, setLocalSelectedChoiceIndex] = useState(0); + + const handleSetSelectedChoiceIndex = (choiceIndex: number) => { + setLocalSelectedChoiceIndex(choiceIndex); + }; if (choices.length === 0) { return null; @@ -26,20 +31,20 @@ export const ChoicesView = ({ } return ( <> - {mode === 'linear' && ( - - )} - {mode === 'carousel' && ( - - )} + + ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx index f4a52fc6801..a34932dea17 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx @@ -1,34 +1,37 @@ -import React, {useState} from 'react'; +import React from 'react'; import {Button} from '../../../../../Button'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewCarouselProps = { choices: Choice[]; isStructuredOutput?: boolean; - setMode: React.Dispatch>; + setIsDrawerOpen: React.Dispatch>; + selectedChoiceIndex: number; + setSelectedChoiceIndex: (choiceIndex: number) => void; }; export const ChoicesViewCarousel = ({ choices, isStructuredOutput, - setMode, + setIsDrawerOpen, + selectedChoiceIndex, + setSelectedChoiceIndex, }: ChoicesViewCarouselProps) => { - const [step, setStep] = useState(0); - const onNext = () => { - setStep((step + 1) % choices.length); + setSelectedChoiceIndex((selectedChoiceIndex + 1) % choices.length); }; const onBack = () => { - const newStep = step === 0 ? choices.length - 1 : step - 1; - setStep(newStep); + const newStep = + selectedChoiceIndex === 0 ? choices.length - 1 : selectedChoiceIndex - 1; + setSelectedChoiceIndex(newStep); }; return ( <>
@@ -37,7 +40,7 @@ export const ChoicesViewCarousel = ({ size="small" variant="quiet" icon="expand-uncollapse" - onClick={() => setMode('linear')} + onClick={() => setIsDrawerOpen(true)} tooltip="Switch to linear view" />
@@ -48,7 +51,7 @@ export const ChoicesViewCarousel = ({ size="small" onClick={onBack} /> - {step + 1} of {choices.length} + {selectedChoiceIndex + 1} of {choices.length}
- -
-
- - + } + /> ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx index 1e778727522..f570b2f6295 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx @@ -18,6 +18,7 @@ type MessagePanelProps = { choiceIndex?: number; isNested?: boolean; pendingToolResponseId?: string; + messageHeader?: React.ReactNode; }; export const MessagePanel = ({ @@ -30,6 +31,7 @@ export const MessagePanel = ({ // If the tool call response is pending, the editor will be shown automatically // and on save the tool call response will be updated and sent to the LLM pendingToolResponseId, + messageHeader, }: MessagePanelProps) => { const [isShowingMore, setIsShowingMore] = useState(false); const [isOverflowing, setIsOverflowing] = useState(false); @@ -116,6 +118,7 @@ export const MessagePanel = ({ 'max-h-[400px]': !isShowingMore, 'max-h-full': isShowingMore, })}> + {messageHeader} {isPlayground && editorHeight ? ( { + console.log('playgroundStates', playgroundStates); const [chatText, setChatText] = useState(''); const [isLoading, setIsLoading] = useState(false); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx index 804670a1dc3..0ce3ad02b51 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx @@ -43,6 +43,8 @@ export const useChatFunctions = ( messageIndex: number, newMessage: Message ) => { + console.log('editMessage', callIndex, messageIndex, newMessage); + setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { const newTraceCall = clearTraceCall( cloneDeep(prevTraceCall as OptionalTraceCallSchema) @@ -106,6 +108,7 @@ export const useChatFunctions = ( choiceIndex: number, newChoice: Message ) => { + console.log('editChoice', callIndex, choiceIndex, newChoice); setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { const newTraceCall = clearTraceCall( cloneDeep(prevTraceCall as OptionalTraceCallSchema) From a0f12639451979efc01ef3b31767444542688ea4 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 11 Dec 2024 16:56:46 -0500 Subject: [PATCH 38/52] chore(weave): Add generic iterator for trace server API objects (#3177) --- weave/trace/weave_client.py | 207 ++++++++++++++++++++++-------------- 1 file changed, 129 insertions(+), 78 deletions(-) diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 0eca3fcbedb..1d5d54b9b23 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -10,7 +10,7 @@ from collections.abc import Iterator, Sequence from concurrent.futures import Future from functools import lru_cache -from typing import Any, Callable, cast +from typing import Any, Callable, Generic, Protocol, TypeVar, cast, overload import pydantic from requests import HTTPError @@ -90,6 +90,128 @@ logger = logging.getLogger(__name__) +T = TypeVar("T") +R = TypeVar("R", covariant=True) + + +class FetchFunc(Protocol[T]): + def __call__(self, offset: int, limit: int) -> list[T]: ... + + +TransformFunc = Callable[[T], R] + + +class PaginatedIterator(Generic[T, R]): + """An iterator that fetches pages of items from a server and optionally transforms them + into a more user-friendly type.""" + + def __init__( + self, + fetch_func: FetchFunc[T], + page_size: int = 1000, + transform_func: TransformFunc[T, R] | None = None, + ) -> None: + self.fetch_func = fetch_func + self.page_size = page_size + self.transform_func = transform_func + + if page_size <= 0: + raise ValueError("page_size must be greater than 0") + + @lru_cache + def _fetch_page(self, index: int) -> list[T]: + return self.fetch_func(index * self.page_size, self.page_size) + + @overload + def _get_one(self: PaginatedIterator[T, T], index: int) -> T: ... + @overload + def _get_one(self: PaginatedIterator[T, R], index: int) -> R: ... + def _get_one(self, index: int) -> T | R: + if index < 0: + raise IndexError("Negative indexing not supported") + + page_index = index // self.page_size + page_offset = index % self.page_size + + page = self._fetch_page(page_index) + if page_offset >= len(page): + raise IndexError(f"Index {index} out of range") + + res = page[page_offset] + if transform := self.transform_func: + return transform(res) + return res + + @overload + def _get_slice(self: PaginatedIterator[T, T], key: slice) -> Iterator[T]: ... + @overload + def _get_slice(self: PaginatedIterator[T, R], key: slice) -> Iterator[R]: ... + def _get_slice(self, key: slice) -> Iterator[T] | Iterator[R]: + if (start := key.start or 0) < 0: + raise ValueError("Negative start not supported") + if (stop := key.stop) is not None and stop < 0: + raise ValueError("Negative stop not supported") + if (step := key.step or 1) < 0: + raise ValueError("Negative step not supported") + + i = start + while stop is None or i < stop: + try: + yield self._get_one(i) + except IndexError: + break + i += step + + @overload + def __getitem__(self: PaginatedIterator[T, T], key: int) -> T: ... + @overload + def __getitem__(self: PaginatedIterator[T, R], key: int) -> R: ... + @overload + def __getitem__(self: PaginatedIterator[T, T], key: slice) -> list[T]: ... + @overload + def __getitem__(self: PaginatedIterator[T, R], key: slice) -> list[R]: ... + def __getitem__(self, key: slice | int) -> T | R | list[T] | list[R]: + if isinstance(key, slice): + return list(self._get_slice(key)) + return self._get_one(key) + + @overload + def __iter__(self: PaginatedIterator[T, T]) -> Iterator[T]: ... + @overload + def __iter__(self: PaginatedIterator[T, R]) -> Iterator[R]: ... + def __iter__(self) -> Iterator[T] | Iterator[R]: + return self._get_slice(slice(0, None, 1)) + + +# TODO: should be Call, not WeaveObject +CallsIter = PaginatedIterator[CallSchema, WeaveObject] + + +def _make_calls_iterator( + server: TraceServerInterface, + project_id: str, + filter: CallsFilter, + include_costs: bool = False, +) -> CallsIter: + def fetch_func(offset: int, limit: int) -> list[CallSchema]: + response = server.calls_query( + CallsQueryReq( + project_id=project_id, + filter=filter, + offset=offset, + limit=limit, + include_costs=include_costs, + ) + ) + return response.calls + + # TODO: Should be Call, not WeaveObject + def transform_func(call: CallSchema) -> WeaveObject: + entity, project = project_id.split("/") + return make_client_call(entity, project, call, server) + + return PaginatedIterator(fetch_func, transform_func=transform_func) + class OpNameError(ValueError): """Raised when an op name is invalid.""" @@ -284,7 +406,7 @@ def children(self) -> CallsIter: ) client = weave_client_context.require_weave_client() - return CallsIter( + return _make_calls_iterator( client.server, self.project_id, CallsFilter(parent_ids=[self.id]), @@ -362,80 +484,6 @@ def _apply_scorer(self, scorer_op: Op) -> None: ) -class CallsIter: - server: TraceServerInterface - filter: CallsFilter - include_costs: bool - - def __init__( - self, - server: TraceServerInterface, - project_id: str, - filter: CallsFilter, - include_costs: bool = False, - ) -> None: - self.server = server - self.project_id = project_id - self.filter = filter - self._page_size = 1000 - self.include_costs = include_costs - - # seems like this caching should be on the server, but it's here for now... - @lru_cache - def _fetch_page(self, index: int) -> list[CallSchema]: - # caching in here means that any other CallsIter objects would also - # benefit from the cache - response = self.server.calls_query( - CallsQueryReq( - project_id=self.project_id, - filter=self.filter, - offset=index * self._page_size, - limit=self._page_size, - include_costs=self.include_costs, - ) - ) - return response.calls - - def _get_one(self, index: int) -> WeaveObject: - if index < 0: - raise IndexError("Negative indexing not supported") - - page_index = index // self._page_size - page_offset = index % self._page_size - - calls = self._fetch_page(page_index) - if page_offset >= len(calls): - raise IndexError(f"Index {index} out of range") - - call = calls[page_offset] - entity, project = self.project_id.split("/") - return make_client_call(entity, project, call, self.server) - - def _get_slice(self, key: slice) -> Iterator[WeaveObject]: - if (start := key.start or 0) < 0: - raise ValueError("Negative start not supported") - if (stop := key.stop) is not None and stop < 0: - raise ValueError("Negative stop not supported") - if (step := key.step or 1) < 0: - raise ValueError("Negative step not supported") - - i = start - while stop is None or i < stop: - try: - yield self._get_one(i) - except IndexError: - break - i += step - - def __getitem__(self, key: slice | int) -> WeaveObject | list[WeaveObject]: - if isinstance(key, slice): - return list(self._get_slice(key)) - return self._get_one(key) - - def __iter__(self) -> Iterator[WeaveObject]: - return self._get_slice(slice(0, None, 1)) - - def make_client_call( entity: str, project: str, server_call: CallSchema, server: TraceServerInterface ) -> WeaveObject: @@ -642,8 +690,11 @@ def get_calls( if filter is None: filter = CallsFilter() - return CallsIter( - self.server, self._project_id(), filter, include_costs or False + return _make_calls_iterator( + self.server, + self._project_id(), + filter, + include_costs, ) @deprecated(new_name="get_calls") From 0525621c503ee6082a2d809f17bef64b04ea71ef Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 11 Dec 2024 17:05:51 -0500 Subject: [PATCH 39/52] feat(weave): Support op configuration for autopatched functions (starting with OpenAI) (#3197) --- tests/conftest.py | 12 +- .../test_configuration_with_dicts.yaml | 102 +++++++++++++ ...est_disabled_integration_doesnt_patch.yaml | 102 +++++++++++++ .../test_enabled_integration_patches.yaml | 102 +++++++++++++ .../test_passthrough_op_kwargs.yaml | 102 +++++++++++++ tests/integrations/openai/test_autopatch.py | 116 +++++++++++++++ weave/integrations/openai/openai_sdk.py | 136 +++++++++++------- weave/scorers/llm_utils.py | 4 - weave/trace/api.py | 9 +- weave/trace/autopatch.py | 59 +++++++- weave/trace/patcher.py | 8 ++ weave/trace/weave_init.py | 6 +- 12 files changed, 689 insertions(+), 69 deletions(-) create mode 100644 tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml create mode 100644 tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml create mode 100644 tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml create mode 100644 tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml create mode 100644 tests/integrations/openai/test_autopatch.py diff --git a/tests/conftest.py b/tests/conftest.py index b28187a3833..85e9b53c36b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -477,7 +477,9 @@ def __getattribute__(self, name): return ServerRecorder(server) -def create_client(request) -> weave_init.InitializedClient: +def create_client( + request, autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None +) -> weave_init.InitializedClient: inited_client = None weave_server_flag = request.config.getoption("--weave-server") server: tsi.TraceServerInterface @@ -513,7 +515,7 @@ def create_client(request) -> weave_init.InitializedClient: entity, project, make_server_recorder(server) ) inited_client = weave_init.InitializedClient(client) - autopatch.autopatch() + autopatch.autopatch(autopatch_settings) return inited_client @@ -527,6 +529,7 @@ def client(request): yield inited_client.client finally: inited_client.reset() + autopatch.reset_autopatch() @pytest.fixture() @@ -534,12 +537,13 @@ def client_creator(request): """This fixture is useful for delaying the creation of the client (ex. when you want to set settings first)""" @contextlib.contextmanager - def client(): - inited_client = create_client(request) + def client(autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None): + inited_client = create_client(request, autopatch_settings) try: yield inited_client.client finally: inited_client.reset() + autopatch.reset_autopatch() yield client diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml new file mode 100644 index 00000000000..7245829a0b3 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFJNa9wwEL37V0x1yWVd7P3KspcSSKE5thvooSlGK40tJbJGSOOSNOx/ + L/Z+2KEp9KLDe/Me743mNQMQVostCGUkqza4/Eav1f1O/v4aVvvbL/Pdt7tVHUp1s/vsnhsx6xW0 + f0TFZ9VHRW1wyJb8kVYRJWPvWl4vFpvNYl2UA9GSRtfLmsD5kvJ5MV/mxSYv1iehIaswiS38yAAA + Xoe3j+g1PostFLMz0mJKskGxvQwBiEiuR4RMySaWnsVsJBV5Rj+k/m5eQJO/YkhP6JDJJ6htYxhQ + KgPEBuOnB//g7w2eJ438hcAGoek4fZgaR6y7JPtevnPuhB8uSR01IdI+nfgLXltvk6kiykS+T5WY + ghjYQwbwc9hI96akCJHawBXTE/resCyPdmL8ggm5PJFMLN2Iz1ezd9wqjSytS5ONCiWVQT0qx/XL + TluaENmk899h3vM+9ra++R/7kVAKA6OuQkRt1dvC41jE/kD/NXbZ8RBYpJfE2Fa19Q3GEO3xRupQ + qWslC9xLJUV2yP4AAAD//wMA4O+DUSwDAAA= + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f01fe3aabd037cf-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 02:20:01 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=xqe_jHZdTV5LijJQYQ3GMY5MjtVrCyxbFO4glgLvgD0-1733883601-1.0.1.1-p.DDUca_cHppJu2hXzzA0CXU1mtalxHUNfBWVgPIQj.UkU603pbNscCvSIi4_Zjlz9Zuc3.hjlvoyZxcDBJTsw; + path=/; expires=Wed, 11-Dec-24 02:50:01 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=WEjxXqkGswaEDhllTROGX_go9tgaWNJcUJ3cCd50xDI-1733883601764-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '607' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_8592a74b531c806f65c63c7471101cb6 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml new file mode 100644 index 00000000000..1895cdcd5f2 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLJbtswEL3rK6a85GIV8lYvl6KX9NQFrYEckkKgyZHImOII5KiJEfjf + C8qLHDQFeuHhbXgzw5cMQFgt1iCUkaya1uWf9Hyuv5H29ebu89Pt80Z9//H0RS2nX/e3LEbJQdtH + VHx2vVfUtA7Zkj/SKqBkTKnjxXS6XCwWk3FPNKTRJVvdcj6jfFJMZnmxzIsPJ6MhqzCKNdxnAAAv + /Zsqeo3PYg3F6Iw0GKOsUawvIgARyCVEyBhtZOmPdU+kIs/o+9Y/u4AjMBjwJoIEZ2vDuUEZGDU8 + 0g6hogB76tYP/sHfmT1o8jcMcYcOmXyEKlkApTJAbDB8TMKNwbPSyN8IbBDqjuO76xoBqy7KtAXf + OXfCD5e5HNVtoG088Re8st5GUwaUkXyaITK1omcPGcCvfn/dq5WINlDTcsm0Q58Cx+NjnBgONpCT + 2YlkYukGfDofvZFWamRpXbzav1BSGdSDcziW7LSlKyK7mvnvMm9lH+e2vv6f+IFQCltGXbYBtVWv + Bx5kAdN3/pfssuO+sIj7yNiUlfU1hjbY44+q2nKl54XSq1WxFdkh+wMAAP//AwAWTTnuWgMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eadbff439d2-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:01 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=8FO1yMjc3pMQWRpWrkIe5mcs39GLeqQPmgHQq0YTT8s-1733877721-1.0.1.1-i4G06DBN08aH1F1H73U_TB9OLK3jLsV1jXydB1cQ4Hqx7I.r8xDn.7hFRZe2hy3D_nABTG1nDcdDoXL_wYiqug; + path=/; expires=Wed, 11-Dec-24 01:12:01 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=jxwySgtriPkUP8L2os1nb_gRq_SSUo3yWFUyJmHPmGY-1733877721989-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '652' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_1c86d4fda2ad715edfd41bcd2f4bdd89 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml new file mode 100644 index 00000000000..f0cdca54158 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLBjtMwEL3nKwZfuDQoTXc3VS8IBOIACCQOHHZR5NrTxDTxWJ4J2rDq + v6Ok2SYrFomLD+/Ne3pvxg8JgHJW7UCZWotpQ5O+sdfXaL9+/PDp+L6gG3ffyudv735v+y82eLUa + FLT/iUYeVa8MtaFBcTTRJqIWHFzXxWazLYoiz0eiJYvNIKuCpFeU5ll+lWbbNLuZhDU5g6x2cJsA + ADyM7xDRW7xXO8hWj0iLzLpCtbsMAahIzYAozexYtBe1mklDXtCPqb/XPVjyLwXYOPTiWBgkdiyg + hVp+fefv/Fs0umMEqbGHVh8RugD4C2MvtfPVi6V3xEPHeqjmu6aZ8NMlbENViLTnib/gB+cd12VE + zeSHYCwU1MieEoAf41K6Jz1ViNQGKYWO6AfD9fpsp+YrLMh8IoVENzOeb1bPuJUWRbuGF0tVRpsa + 7aycL6A762hBJIvOf4d5zvvc2/nqf+xnwhgMgrYMEa0zTwvPYxGHP/qvscuOx8CKexZsy4PzFcYQ + 3fmbHEJpCqMz3GujVXJK/gAAAP//AwAyhdwOLwMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eb36bb3a240-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:02 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=Q_ATX8JU4jFqXJPdwlneOua9wmNmAaASyAfcbPyPqng-1733877722-1.0.1.1-eTMEvBW7oqQa2i3l.Or2I3LF_cCESxfseq.S9DBr8dAJWsVoFfPxKtr5vMaO6yj4hRW8XOSOHcgIcwwqbHrLbg; + path=/; expires=Wed, 11-Dec-24 01:12:02 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=2ak.tRpn6uEHbM8GrWy_ALtrN34jVSNIJI1mFG2etvM-1733877722703-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '476' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_52e061e1cc55cdd8847a7ba9342f1a14 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml new file mode 100644 index 00000000000..646c57c6123 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLLbtswELzrK7a89GIVsmLXsS9Fr0UvBQIERVMINLkS2VBcglwVcQL/ + e0H5IQVNgV54mNkZzCz3pQAQVosdCGUkqz648rNer7G2z3fL8ITq+X6lvn7x/SHhN/d9LxZZQftf + qPii+qCoDw7Zkj/RKqJkzK7Lzc3N7WazqeuR6Emjy7IucLmisq7qVVndltXHs9CQVZjEDn4UAAAv + 45sjeo1PYgfV4oL0mJLsUOyuQwAiksuIkCnZxNKzWEykIs/ox9T35gCa/HuG9IgOmXyC1naGAaUy + QGwwfnrwD/7O4GXSyN8IbBC6gdO7uXHEdkgy9/KDc2f8eE3qqAuR9unMX/HWeptME1Em8jlVYgpi + ZI8FwM9xI8OrkiJE6gM3TI/os+FyebIT0xfMyNWZZGLpJrxeL95wazSytC7NNiqUVAb1pJzWLwdt + aUYUs85/h3nL+9Tb+u5/7CdCKQyMugkRtVWvC09jEfOB/mvsuuMxsEiHxNg3rfUdxhDt6Uba0Gz1 + ulJ6u632ojgWfwAAAP//AwCOwDMjLAMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eb76b71ac9a-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:03 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=r.xSSsYQNFPvMiizFSvjQiecNA6Q1wQa0VR1YElfXi4-1733877723-1.0.1.1-GVW0i7wrpHCQSY5eXu7sIQgxYWl6jfeSordQ7JFxV3lO6UfFhwxRT92bBP4DfnrSYpBpRw3k4aONAURyvKctiQ; + path=/; expires=Wed, 11-Dec-24 01:12:03 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=CQJVOdASzL9ency5_q6SDaInTsvpjA240cIxf.AUwXM-1733877723385-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '523' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_c9c57cfa6f37a99aaf0abac013237ed6 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/test_autopatch.py b/tests/integrations/openai/test_autopatch.py new file mode 100644 index 00000000000..2c2f5201d3f --- /dev/null +++ b/tests/integrations/openai/test_autopatch.py @@ -0,0 +1,116 @@ +# This is included here for convenience. Instead of creating a dummy API, we can test +# autopatching against the actual OpenAI API. + +from typing import Any + +import pytest +from openai import OpenAI + +from weave.integrations.openai import openai_sdk +from weave.trace.autopatch import AutopatchSettings, IntegrationSettings, OpSettings + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_disabled_integration_doesnt_patch(client_creator): + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings(enabled=False), + ) + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 0 + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_enabled_integration_patches(client_creator): + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings(enabled=True), + ) + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_passthrough_op_kwargs(client_creator): + def redact_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + return dict.fromkeys(inputs, "REDACTED") + + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings( + op_settings=OpSettings( + postprocess_inputs=redact_inputs, + ) + ) + ) + + # Explicitly reset the patcher here to pretend like we're starting fresh. We need + # to do this because `_openai_patcher` is a global variable that is shared across + # tests. If we don't reset it, it will retain the state from the previous test, + # which can cause this test to fail. + openai_sdk._openai_patcher = None + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + call = calls[0] + assert all(v == "REDACTED" for v in call.inputs.values()) + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_configuration_with_dicts(client_creator): + def redact_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + return dict.fromkeys(inputs, "REDACTED") + + autopatch_settings = { + "openai": { + "op_settings": {"postprocess_inputs": redact_inputs}, + } + } + + openai_sdk._openai_patcher = None + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + call = calls[0] + assert all(v == "REDACTED" for v in call.inputs.values()) diff --git a/weave/integrations/openai/openai_sdk.py b/weave/integrations/openai/openai_sdk.py index 7814700d4d3..a1e3a9b5831 100644 --- a/weave/integrations/openai/openai_sdk.py +++ b/weave/integrations/openai/openai_sdk.py @@ -1,15 +1,20 @@ +from __future__ import annotations + import importlib from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op import Op, ProcessedInputs from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from openai.types.chat import ChatCompletionChunk +_openai_patcher: MultiPatcher | None = None + def maybe_unwrap_api_response(value: Any) -> Any: """If the caller requests a raw response, we unwrap the APIResponse object. @@ -43,9 +48,7 @@ def maybe_unwrap_api_response(value: Any) -> Any: return value -def openai_on_finish_post_processor( - value: Optional["ChatCompletionChunk"], -) -> Optional[dict]: +def openai_on_finish_post_processor(value: ChatCompletionChunk | None) -> dict | None: from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion_chunk import ( ChoiceDeltaFunctionCall, @@ -60,8 +63,8 @@ def openai_on_finish_post_processor( value = maybe_unwrap_api_response(value) def _get_function_call( - function_call: Optional[ChoiceDeltaFunctionCall], - ) -> Optional[FunctionCall]: + function_call: ChoiceDeltaFunctionCall | None, + ) -> FunctionCall | None: if function_call is None: return function_call if isinstance(function_call, ChoiceDeltaFunctionCall): @@ -73,8 +76,8 @@ def _get_function_call( return None def _get_tool_calls( - tool_calls: Optional[list[ChoiceDeltaToolCall]], - ) -> Optional[list[ChatCompletionMessageToolCall]]: + tool_calls: list[ChoiceDeltaToolCall] | None, + ) -> list[ChatCompletionMessageToolCall] | None: if tool_calls is None: return tool_calls @@ -128,10 +131,10 @@ def _get_tool_calls( def openai_accumulator( - acc: Optional["ChatCompletionChunk"], - value: "ChatCompletionChunk", + acc: ChatCompletionChunk | None, + value: ChatCompletionChunk, skip_last: bool = False, -) -> "ChatCompletionChunk": +) -> ChatCompletionChunk: from openai.types.chat import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import ( ChoiceDeltaFunctionCall, @@ -285,7 +288,7 @@ def should_use_accumulator(inputs: dict) -> bool: def openai_on_input_handler( func: Op, args: tuple, kwargs: dict -) -> Optional[ProcessedInputs]: +) -> ProcessedInputs | None: if len(args) == 2 and isinstance(args[1], weave.EasyPrompt): original_args = args original_kwargs = kwargs @@ -305,20 +308,16 @@ def openai_on_input_handler( return None -def create_wrapper_sync( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: "We need to do this so we can check if `stream` is used" def _add_stream_options(fn: Callable) -> Callable: @wraps(fn) def _wrapper(*args: Any, **kwargs: Any) -> Any: - if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None: + if kwargs.get("stream") and kwargs.get("stream_options") is None: kwargs["stream_options"] = {"include_usage": True} - return fn( - *args, **kwargs - ) # This is where the final execution of fn is happening. + return fn(*args, **kwargs) return _wrapper @@ -327,8 +326,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return True return False - op = weave.op()(_add_stream_options(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_add_stream_options(fn), **op_kwargs) op._set_on_input_handler(openai_on_input_handler) return add_accumulator( op, # type: ignore @@ -345,16 +344,14 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: # Surprisingly, the async `client.chat.completions.create` does not pass # `inspect.iscoroutinefunction`, so we can't dispatch on it and must write # it manually here... -def create_wrapper_async( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: "We need to do this so we can check if `stream` is used" def _add_stream_options(fn: Callable) -> Callable: @wraps(fn) async def _wrapper(*args: Any, **kwargs: Any) -> Any: - if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None: + if kwargs.get("stream") and kwargs.get("stream_options") is None: kwargs["stream_options"] = {"include_usage": True} return await fn(*args, **kwargs) @@ -365,8 +362,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return True return False - op = weave.op()(_add_stream_options(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_add_stream_options(fn), **op_kwargs) op._set_on_input_handler(openai_on_input_handler) return add_accumulator( op, # type: ignore @@ -380,28 +377,61 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return wrapper -symbol_patchers = [ - # Patch the Completions.create method - SymbolPatcher( - lambda: importlib.import_module("openai.resources.chat.completions"), - "Completions.create", - create_wrapper_sync(name="openai.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.chat.completions"), - "AsyncCompletions.create", - create_wrapper_async(name="openai.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.beta.chat.completions"), - "Completions.parse", - create_wrapper_sync(name="openai.beta.chat.completions.parse"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.beta.chat.completions"), - "AsyncCompletions.parse", - create_wrapper_async(name="openai.beta.chat.completions.parse"), - ), -] - -openai_patcher = MultiPatcher(symbol_patchers) # type: ignore +def get_openai_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _openai_patcher + if _openai_patcher is not None: + return _openai_patcher + + base = settings.op_settings + + completions_create_settings = base.model_copy( + update={"name": base.name or "openai.chat.completions.create"} + ) + async_completions_create_settings = base.model_copy( + update={"name": base.name or "openai.chat.completions.create"} + ) + completions_parse_settings = base.model_copy( + update={"name": base.name or "openai.beta.chat.completions.parse"} + ) + async_completions_parse_settings = base.model_copy( + update={"name": base.name or "openai.beta.chat.completions.parse"} + ) + + _openai_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("openai.resources.chat.completions"), + "Completions.create", + create_wrapper_sync(settings=completions_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("openai.resources.chat.completions"), + "AsyncCompletions.create", + create_wrapper_async(settings=async_completions_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module( + "openai.resources.beta.chat.completions" + ), + "Completions.parse", + create_wrapper_sync(settings=completions_parse_settings), + ), + SymbolPatcher( + lambda: importlib.import_module( + "openai.resources.beta.chat.completions" + ), + "AsyncCompletions.parse", + create_wrapper_async(settings=async_completions_parse_settings), + ), + ] + ) + + return _openai_patcher diff --git a/weave/scorers/llm_utils.py b/weave/scorers/llm_utils.py index 68ae2ccb366..eef6f018b0f 100644 --- a/weave/scorers/llm_utils.py +++ b/weave/scorers/llm_utils.py @@ -2,10 +2,6 @@ from typing import TYPE_CHECKING, Any, Union -from weave.trace.autopatch import autopatch - -autopatch() # ensure both weave patching and instructor patching are applied - OPENAI_DEFAULT_MODEL = "gpt-4o" OPENAI_DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small" OPENAI_DEFAULT_MODERATION_MODEL = "text-moderation-latest" diff --git a/weave/trace/api.py b/weave/trace/api.py index ee8131b0875..294308cbb67 100644 --- a/weave/trace/api.py +++ b/weave/trace/api.py @@ -13,6 +13,7 @@ # There is probably a better place for this, but including here for now to get the fix in. from weave import type_handlers # noqa: F401 from weave.trace import urls, util, weave_client, weave_init +from weave.trace.autopatch import AutopatchSettings from weave.trace.constants import TRACE_OBJECT_EMOJI from weave.trace.context import call_context from weave.trace.context import weave_client_context as weave_client_context @@ -32,6 +33,7 @@ def init( project_name: str, *, settings: UserSettings | dict[str, Any] | None = None, + autopatch_settings: AutopatchSettings | None = None, ) -> weave_client.WeaveClient: """Initialize weave tracking, logging to a wandb project. @@ -52,7 +54,12 @@ def init( if should_disable_weave(): return weave_init.init_weave_disabled().client - return weave_init.init_weave(project_name).client + initialized_client = weave_init.init_weave( + project_name, + autopatch_settings=autopatch_settings, + ) + + return initialized_client.client @contextlib.contextmanager diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py index 3a5dca14556..0619194a224 100644 --- a/weave/trace/autopatch.py +++ b/weave/trace/autopatch.py @@ -4,8 +4,54 @@ check if libraries are installed and imported and patch in the case that they are. """ +from typing import Any, Callable, Optional, Union -def autopatch() -> None: +from pydantic import BaseModel, Field, validate_call + +from weave.trace.weave_client import Call + + +class OpSettings(BaseModel): + """Op settings for a specific integration. + These currently subset the `op` decorator args to provide a consistent interface + when working with auto-patched functions. See the `op` decorator for more details.""" + + name: Optional[str] = None + call_display_name: Optional[Union[str, Callable[[Call], str]]] = None + postprocess_inputs: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None + postprocess_output: Optional[Callable[[Any], Any]] = None + + +class IntegrationSettings(BaseModel): + """Configuration for a specific integration.""" + + enabled: bool = True + op_settings: OpSettings = Field(default_factory=OpSettings) + + +class AutopatchSettings(BaseModel): + """Settings for auto-patching integrations.""" + + # These will be uncommented as we add support for more integrations. Note that + + # anthropic: IntegrationSettings = Field(default_factory=IntegrationSettings) + # cerebras: IntegrationSettings = Field(default_factory=IntegrationSettings) + # cohere: IntegrationSettings = Field(default_factory=IntegrationSettings) + # dspy: IntegrationSettings = Field(default_factory=IntegrationSettings) + # google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings) + # groq: IntegrationSettings = Field(default_factory=IntegrationSettings) + # instructor: IntegrationSettings = Field(default_factory=IntegrationSettings) + # langchain: IntegrationSettings = Field(default_factory=IntegrationSettings) + # litellm: IntegrationSettings = Field(default_factory=IntegrationSettings) + # llamaindex: IntegrationSettings = Field(default_factory=IntegrationSettings) + # mistral: IntegrationSettings = Field(default_factory=IntegrationSettings) + # notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings) + openai: IntegrationSettings = Field(default_factory=IntegrationSettings) + # vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings) + + +@validate_call +def autopatch(settings: Optional[AutopatchSettings] = None) -> None: from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher from weave.integrations.cohere.cohere_sdk import cohere_patcher @@ -20,10 +66,13 @@ def autopatch() -> None: from weave.integrations.llamaindex.llamaindex import llamaindex_patcher from weave.integrations.mistral import mistral_patcher from weave.integrations.notdiamond.tracing import notdiamond_patcher - from weave.integrations.openai.openai_sdk import openai_patcher + from weave.integrations.openai.openai_sdk import get_openai_patcher from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher - openai_patcher.attempt_patch() + if settings is None: + settings = AutopatchSettings() + + get_openai_patcher(settings.openai).attempt_patch() mistral_patcher.attempt_patch() litellm_patcher.attempt_patch() llamaindex_patcher.attempt_patch() @@ -54,10 +103,10 @@ def reset_autopatch() -> None: from weave.integrations.llamaindex.llamaindex import llamaindex_patcher from weave.integrations.mistral import mistral_patcher from weave.integrations.notdiamond.tracing import notdiamond_patcher - from weave.integrations.openai.openai_sdk import openai_patcher + from weave.integrations.openai.openai_sdk import get_openai_patcher from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher - openai_patcher.undo_patch() + get_openai_patcher().undo_patch() mistral_patcher.undo_patch() litellm_patcher.undo_patch() llamaindex_patcher.undo_patch() diff --git a/weave/trace/patcher.py b/weave/trace/patcher.py index 1567c4e2bb9..c1d0d653ffa 100644 --- a/weave/trace/patcher.py +++ b/weave/trace/patcher.py @@ -17,6 +17,14 @@ def undo_patch(self) -> bool: raise NotImplementedError() +class NoOpPatcher(Patcher): + def attempt_patch(self) -> bool: + return True + + def undo_patch(self) -> bool: + return True + + class MultiPatcher(Patcher): def __init__(self, patchers: Sequence[Patcher]) -> None: self.patchers = patchers diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py index 563dcbdaed4..f51d42d5018 100644 --- a/weave/trace/weave_init.py +++ b/weave/trace/weave_init.py @@ -63,7 +63,9 @@ def get_entity_project_from_project_name(project_name: str) -> tuple[str, str]: def init_weave( - project_name: str, ensure_project_exists: bool = True + project_name: str, + ensure_project_exists: bool = True, + autopatch_settings: autopatch.AutopatchSettings | None = None, ) -> InitializedClient: global _current_inited_client if _current_inited_client is not None: @@ -120,7 +122,7 @@ def init_weave( # autopatching is only supported for the wandb client, because OpenAI calls are not # logged in local mode currently. When that's fixed, this autopatch call can be # moved to InitializedClient.__init__ - autopatch.autopatch() + autopatch.autopatch(autopatch_settings) username = get_username() try: From 16e47c3a8d804db1f7c8c80fe53b5b58082a6757 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Wed, 11 Dec 2024 16:19:24 -0600 Subject: [PATCH 40/52] chore(ui): update UUID dependency to v11 (latest) (#3208) --- weave-js/package.json | 3 +-- weave-js/yarn.lock | 15 +++++---------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/weave-js/package.json b/weave-js/package.json index cb57125143a..1f551021ed4 100644 --- a/weave-js/package.json +++ b/weave-js/package.json @@ -192,7 +192,6 @@ "@types/react-virtualized-auto-sizer": "^1.0.0", "@types/safe-json-stringify": "^1.1.2", "@types/styled-components": "^5.1.26", - "@types/uuid": "^9.0.1", "@types/wavesurfer.js": "^2.0.0", "@types/zen-observable": "^0.8.3", "@typescript-eslint/eslint-plugin": "5.35.1", @@ -237,7 +236,7 @@ "tslint-config-prettier": "^1.18.0", "tslint-plugin-prettier": "^2.3.0", "typescript": "4.7.4", - "uuid": "^9.0.0", + "uuid": "^11.0.3", "vite": "5.2.9", "vitest": "^1.6.0" }, diff --git a/weave-js/yarn.lock b/weave-js/yarn.lock index c7f9379e32a..6a5ec14e872 100644 --- a/weave-js/yarn.lock +++ b/weave-js/yarn.lock @@ -4776,11 +4776,6 @@ resolved "https://registry.yarnpkg.com/@types/unist/-/unist-2.0.7.tgz#5b06ad6894b236a1d2bd6b2f07850ca5c59cf4d6" integrity sha512-cputDpIbFgLUaGQn6Vqg3/YsJwxUwHLO13v3i5ouxT4lat0khip9AEWxtERujXV9wxIB1EyF97BSJFt6vpdI8g== -"@types/uuid@^9.0.1": - version "9.0.2" - resolved "https://registry.yarnpkg.com/@types/uuid/-/uuid-9.0.2.tgz#ede1d1b1e451548d44919dc226253e32a6952c4b" - integrity sha512-kNnC1GFBLuhImSnV7w4njQkUiJi0ZXUycu1rUaouPqiKlXkh77JKgdRnTAp1x5eBwcIwbtI+3otwzuIDEuDoxQ== - "@types/wavesurfer.js@^2.0.0": version "2.0.2" resolved "https://registry.yarnpkg.com/@types/wavesurfer.js/-/wavesurfer.js-2.0.2.tgz#b98a4d57ca24ee2028ae6dd5c2208b568bb73842" @@ -15032,6 +15027,11 @@ util-deprecate@^1.0.1, util-deprecate@^1.0.2, util-deprecate@~1.0.1: resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw== +uuid@^11.0.3: + version "11.0.3" + resolved "https://registry.yarnpkg.com/uuid/-/uuid-11.0.3.tgz#248451cac9d1a4a4128033e765d137e2b2c49a3d" + integrity sha512-d0z310fCWv5dJwnX1Y/MncBAqGMKEzlBb1AOf7z9K8ALnd0utBX/msg/fA0+sbyN1ihbMsLhrBlnl1ak7Wa0rg== + uuid@^2.0.2: version "2.0.3" resolved "https://registry.yarnpkg.com/uuid/-/uuid-2.0.3.tgz#67e2e863797215530dff318e5bf9dcebfd47b21a" @@ -15042,11 +15042,6 @@ uuid@^3.0.0, uuid@^3.4.0: resolved "https://registry.yarnpkg.com/uuid/-/uuid-3.4.0.tgz#b23e4358afa8a202fe7a100af1f5f883f02007ee" integrity sha512-HjSDRw6gZE5JMggctHBcjVak08+KEVhSIiDzFnT9S9aegmp85S/bReBVTb4QTFaRNptJ9kuYaNhnbNEOkbKb/A== -uuid@^9.0.0: - version "9.0.0" - resolved "https://registry.yarnpkg.com/uuid/-/uuid-9.0.0.tgz#592f550650024a38ceb0c562f2f6aa435761efb5" - integrity sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg== - uvu@^0.5.0: version "0.5.6" resolved "https://registry.yarnpkg.com/uvu/-/uvu-0.5.6.tgz#2754ca20bcb0bb59b64e9985e84d2e81058502df" From 444c04dcdeb965b531d7508f5e82af56cfcb00f2 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Wed, 11 Dec 2024 18:50:14 -0600 Subject: [PATCH 41/52] chore(ui): remove some unused code (#3157) --- .../PagePanelComponents/Home/Browse3.tsx | 138 +----------------- 1 file changed, 1 insertion(+), 137 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index 3517f4d3b9c..761fd536930 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -1,15 +1,5 @@ import {ApolloProvider} from '@apollo/client'; -import {Home} from '@mui/icons-material'; -import { - AppBar, - Box, - Breadcrumbs, - Drawer, - IconButton, - Link as MaterialLink, - Toolbar, - Typography, -} from '@mui/material'; +import {Box, Drawer} from '@mui/material'; import { GridColumnVisibilityModel, GridFilterModel, @@ -21,9 +11,7 @@ import {LicenseInfo} from '@mui/x-license'; import {makeGorillaApolloClient} from '@wandb/weave/apollo'; import {EVALUATE_OP_NAME_POST_PYDANTIC} from '@wandb/weave/components/PagePanelComponents/Home/Browse3/pages/common/heuristics'; import {opVersionKeyToRefUri} from '@wandb/weave/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/utilities'; -import _ from 'lodash'; import React, { - ComponentProps, FC, useCallback, useEffect, @@ -33,7 +21,6 @@ import React, { } from 'react'; import useMousetrap from 'react-hook-mousetrap'; import { - Link as RouterLink, Redirect, Route, Switch, @@ -199,7 +186,6 @@ export const Browse3: FC<{ `/${URL_BROWSE3}`, ]}> @@ -211,7 +197,6 @@ export const Browse3: FC<{ }; const Browse3Mounted: FC<{ - hideHeader?: boolean; headerOffset?: number; navigateAwayFromProject?: () => void; }> = props => { @@ -225,37 +210,6 @@ const Browse3Mounted: FC<{ overflow: 'auto', flexDirection: 'column', }}> - {!props.hideHeader && ( - theme.zIndex.drawer + 1, - height: '60px', - flex: '0 0 auto', - position: 'static', - }}> - - - theme.palette.getContrastText(theme.palette.primary.main), - '&:hover': { - color: theme => - theme.palette.getContrastText(theme.palette.primary.dark), - }, - marginRight: theme => theme.spacing(2), - }}> - - - - - - )} @@ -1050,20 +1004,6 @@ const ComparePageBinding = () => { return ; }; -const AppBarLink = (props: ComponentProps) => ( - theme.palette.getContrastText(theme.palette.primary.main), - '&:hover': { - color: theme => - theme.palette.getContrastText(theme.palette.primary.dark), - }, - }} - {...props} - component={RouterLink} - /> -); - const PlaygroundPageBinding = () => { const params = useParamsDecoded(); return ( @@ -1074,79 +1014,3 @@ const PlaygroundPageBinding = () => { /> ); }; - -const Browse3Breadcrumbs: FC = props => { - const params = useParamsDecoded(); - const query = useURLSearchParamsDict(); - const filePathParts = query.path?.split('/') ?? []; - const refFields = query.extra?.split('/') ?? []; - - return ( - - {params.entity && ( - - {params.entity} - - )} - {params.project && ( - - {params.project} - - )} - {params.tab && ( - - {params.tab} - - )} - {params.itemName && ( - - {params.itemName} - - )} - {params.version && ( - - {params.version} - - )} - {filePathParts.map((part, idx) => ( - - {part} - - ))} - {_.range(0, refFields.length, 2).map(idx => ( - - - theme.palette.getContrastText(theme.palette.primary.main), - }}> - {refFields[idx]} - - - {refFields[idx + 1]} - - - ))} - - ); -}; From 96d1d0d0f48cd571ae7b7737de2fd58d663588d9 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Wed, 11 Dec 2024 17:55:26 -0800 Subject: [PATCH 42/52] chore(ui): Create scorer drawer style + small annotation drawer style tweaks (#3186) --- .../StructuredFeedback/HumanAnnotation.tsx | 4 +- .../ScorersPage/AnnotationScorerForm.tsx | 9 +- .../pages/ScorersPage/FormComponents.tsx | 2 +- .../pages/ScorersPage/NewScorerDrawer.tsx | 25 +++- .../pages/ScorersPage/ZodSchemaForm.tsx | 133 +++++++++++------- 5 files changed, 111 insertions(+), 62 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx index 7facffe9556..2821c02affa 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx @@ -415,7 +415,7 @@ export const TextFeedbackColumn = ({ placeholder="" /> {maxLength && ( -
+
{`Maximum characters: ${maxLength}`}
)} @@ -603,7 +603,7 @@ export const NumericalTextField: React.FC = ({ errorState={error} /> {(min != null || max != null) && ( -
+
{isInteger ? 'Integer required. ' : ''} {min != null && `Min: ${min}`} {min != null && max != null && ', '} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx index 9acbdfe6c2f..a478437facb 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx @@ -1,5 +1,5 @@ import {Box} from '@material-ui/core'; -import React, {FC, useCallback, useState} from 'react'; +import React, {FC, useCallback, useEffect, useState} from 'react'; import {z} from 'zod'; import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; @@ -28,7 +28,7 @@ const AnnotationScorerFormSchema = z.object({ }), z.object({ type: z.literal('String'), - 'Max length': z.number().optional(), + 'Maximum length': z.number().optional(), }), z.object({ type: z.literal('Select'), @@ -45,6 +45,9 @@ export const AnnotationScorerForm: FC< ScorerFormProps> > = ({data, onDataChange}) => { const [config, setConfig] = useState(data ?? DEFAULT_STATE); + useEffect(() => { + setConfig(data ?? DEFAULT_STATE); + }, [data]); const [isValid, setIsValid] = useState(false); const handleConfigChange = useCallback( @@ -113,7 +116,7 @@ function convertTypeExtrasToJsonSchema( const typeSchema = obj.Type; const typeExtras: Record = {}; if (typeSchema.type === 'String') { - typeExtras.maxLength = typeSchema['Max length']; + typeExtras.maxLength = typeSchema['Maximum length']; } else if (typeSchema.type === 'Integer' || typeSchema.type === 'Number') { typeExtras.minimum = typeSchema.Minimum; typeExtras.maximum = typeSchema.Maximum; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx index 2716bfbfa81..250c896cfea 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx @@ -3,7 +3,7 @@ import {Select} from '@wandb/weave/components/Form/Select'; import {TextField} from '@wandb/weave/components/Form/TextField'; import React from 'react'; -export const GAP_BETWEEN_ITEMS_PX = 10; +export const GAP_BETWEEN_ITEMS_PX = 16; export const GAP_BETWEEN_LABEL_AND_FIELD_PX = 10; type AutocompleteWithLabelType