From a7170b4c1e6dbca9b279865b588d4edfee34468f Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:00:13 -0600 Subject: [PATCH 01/12] feat(ui): Support Anthropic calls in Chat View (#3187) --- .../Home/Browse3/pages/ChatView/hooks.ts | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts index 38b5c820195..33ced58ec49 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts @@ -246,6 +246,63 @@ export const isTraceCallChatFormatGemini = (call: TraceCallSchema): boolean => { ); }; +export const isAnthropicContentBlock = (block: any): boolean => { + if (!_.isPlainObject(block)) { + return false; + } + // TODO: Are there other types? + if (block.type !== 'text') { + return false; + } + if (!hasStringProp(block, 'text')) { + return false; + } + return true; +}; + +export const isAnthropicCompletionFormat = (output: any): boolean => { + if (output !== null) { + // TODO: Could have additional checks here on things like usage + if ( + _.isPlainObject(output) && + output.type === 'message' && + output.role === 'assistant' && + hasStringProp(output, 'model') && + _.isArray(output.content) && + output.content.every((c: any) => isAnthropicContentBlock(c)) + ) { + return true; + } + return false; + } + return true; +}; + +type AnthropicContentBlock = { + type: 'text'; + text: string; +}; + +export const anthropicContentBlocksToChoices = ( + blocks: AnthropicContentBlock[], + stopReason: string +): Choice[] => { + const choices: Choice[] = []; + for (let i = 0; i < blocks.length; i++) { + const block = blocks[i]; + choices.push({ + index: i, + message: { + role: 'assistant', + content: block.text, + }, + // TODO: What is correct way to map this? + finish_reason: stopReason, + }); + } + return choices; +}; + export const isTraceCallChatFormatOpenAI = (call: TraceCallSchema): boolean => { if (!('messages' in call.inputs)) { return false; @@ -336,6 +393,19 @@ export const normalizeChatRequest = (request: any): ChatRequest => { ], }; } + // Anthropic has system message as a top-level request field + if (hasStringProp(request, 'system')) { + return { + ...request, + messages: [ + { + role: 'system', + content: request.system, + }, + ...request.messages, + ], + }; + } return request as ChatRequest; }; @@ -360,6 +430,24 @@ export const normalizeChatCompletion = ( }, }; } + if (isAnthropicCompletionFormat(completion)) { + return { + id: completion.id, + choices: anthropicContentBlocksToChoices( + completion.content, + completion.stop_reason + ), + created: 0, + model: completion.model, + system_fingerprint: '', + usage: { + prompt_tokens: completion.usage.input_tokens, + completion_tokens: completion.usage.output_tokens, + total_tokens: + completion.usage.input_tokens + completion.usage.output_tokens, + }, + }; + } return completion as ChatCompletion; }; From 935f68cbfdd55a2f65a476b201d2227167416097 Mon Sep 17 00:00:00 2001 From: Connie Lee Date: Tue, 10 Dec 2024 09:28:08 -0800 Subject: [PATCH 02/12] feat(ui): Add ExpandingPill component (#3184) --- weave-js/src/components/Tag/Pill.tsx | 39 ++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/weave-js/src/components/Tag/Pill.tsx b/weave-js/src/components/Tag/Pill.tsx index 6f734f9cbfc..ed958a0c543 100644 --- a/weave-js/src/components/Tag/Pill.tsx +++ b/weave-js/src/components/Tag/Pill.tsx @@ -59,3 +59,42 @@ export const IconOnlyPill: FC = ({ ); }; + +export type ExpandingPillProps = { + className?: string; + color?: TagColorName; + icon: IconName; + label: string; +}; +export const ExpandingPill = ({ + className, + color, + icon, + label, +}: ExpandingPillProps) => { + const classes = useTagClasses({color, isInteractive: true}); + return ( + +
+ + + {label} + +
+
+ ); +}; From a14edbe6f6c37f3e9094b50ffdd2adc5bd3ecb65 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:52:04 -0600 Subject: [PATCH 03/12] chore(weave): ui_url error gives a hint about the problem (#3190) --- weave/trace/weave_client.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 86ba7b8653e..0eca3fcbedb 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -239,7 +239,9 @@ def func_name(self) -> str: @property def feedback(self) -> RefFeedbackQuery: if not self.id: - raise ValueError("Can't get feedback for call without ID") + raise ValueError( + "Can't get feedback for call without ID, was `weave.init` called?" + ) if self._feedback is None: try: @@ -253,7 +255,9 @@ def feedback(self) -> RefFeedbackQuery: @property def ui_url(self) -> str: if not self.id: - raise ValueError("Can't get URL for call without ID") + raise ValueError( + "Can't get URL for call without ID, was `weave.init` called?" + ) try: entity, project = self.project_id.split("/") @@ -265,7 +269,9 @@ def ui_url(self) -> str: def ref(self) -> CallRef: entity, project = self.project_id.split("/") if not self.id: - raise ValueError("Can't get ref for call without ID") + raise ValueError( + "Can't get ref for call without ID, was `weave.init` called?" + ) return CallRef(entity, project, self.id) @@ -273,7 +279,9 @@ def ref(self) -> CallRef: def children(self) -> CallsIter: client = weave_client_context.require_weave_client() if not self.id: - raise ValueError("Can't get children of call without ID") + raise ValueError( + "Can't get children of call without ID, was `weave.init` called?" + ) client = weave_client_context.require_weave_client() return CallsIter( From c19fe6bf246c058a3cbaa11bd544eb2e2fbbd24c Mon Sep 17 00:00:00 2001 From: brianlund-wandb Date: Tue, 10 Dec 2024 10:36:51 -0800 Subject: [PATCH 04/12] fix(app): render bounding boxes on media panel without points (#3158) * WB-16623 coercing empty point to default, targeting camera on bounding box when point cloud empty * code style cleanup --- .../src/common/util/SdkPointCloudToBabylon.test.ts | 14 ++++++++++++++ weave-js/src/common/util/SdkPointCloudToBabylon.ts | 2 +- weave-js/src/common/util/render_babylon.ts | 13 +++++++++++-- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts b/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts index 95c7639dc4a..df5eca57b46 100644 --- a/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts +++ b/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts @@ -4,6 +4,7 @@ import { DEFAULT_POINT_COLOR, getFilteringOptionsForPointCloud, getVertexCompatiblePositionsAndColors, + loadPointCloud, MAX_BOUNDING_BOX_LABELS_FOR_DISPLAY, MaxAlphaValue, } from './SdkPointCloudToBabylon'; @@ -174,3 +175,16 @@ describe('getFilteringOptionsForPointCloud', () => { expect(newClassIdToLabel.get(49)).toEqual('label49'); }); }); +describe('loadPointCloud', () => { + it('appropriate defaults set when loading point cloud from file', () => { + const fileContents = JSON.stringify({ + boxes: [], + points: [[]], + type: 'lidar/beta', + vectors: [], + }); + const babylonPointCloud = loadPointCloud(fileContents); + expect(babylonPointCloud.points).toHaveLength(1); + expect(babylonPointCloud.points[0].position).toEqual([0, 0, 0]); + }); +}); diff --git a/weave-js/src/common/util/SdkPointCloudToBabylon.ts b/weave-js/src/common/util/SdkPointCloudToBabylon.ts index 274e1676be4..d52682743ee 100644 --- a/weave-js/src/common/util/SdkPointCloudToBabylon.ts +++ b/weave-js/src/common/util/SdkPointCloudToBabylon.ts @@ -160,7 +160,7 @@ export const handlePoints = (object3D: Object3DScene): ScenePoint[] => { // Draw Points return truncatedPoints.map(point => { const [x, y, z, r, g, b] = point; - const position: Position = [x, y, z]; + const position: Position = [x ?? 0, y ?? 0, z ?? 0]; const category = r; if (r !== undefined && g !== undefined && b !== undefined) { diff --git a/weave-js/src/common/util/render_babylon.ts b/weave-js/src/common/util/render_babylon.ts index 10aee3f6c51..ebd213c2677 100644 --- a/weave-js/src/common/util/render_babylon.ts +++ b/weave-js/src/common/util/render_babylon.ts @@ -394,6 +394,15 @@ const pointCloudScene = ( // Apply vertexData to custom mesh vertexData.applyToMesh(pcMesh); + // A file without any points defined still includes a single, empty "point". + // In order to play nice with Babylon, we position this empty point at 0,0,0. + // Hence, a pointCloud with a single point at 0,0,0 is likely empty. + const isEmpty = + pointCloud.points.length === 1 && + pointCloud.points[0].position[0] === 0 && + pointCloud.points[0].position[1] === 0 && + pointCloud.points[0].position[2] === 0; + camera.parent = pcMesh; const pcMaterial = new Babylon.StandardMaterial('mat', scene); @@ -472,8 +481,8 @@ const pointCloudScene = ( new Vector3(edges.length * 2, edges.length * 2, edges.length * 2) ); - // If we are iterating over camera, target a box - if (index === meta?.cameraIndex) { + // If we are iterating over camera or the cloud is empty, target a box + if (index === meta?.cameraIndex || (index === 0 && isEmpty)) { camera.position = center.add(new Vector3(0, 0, 1000)); camera.target = center; camera.zoomOn([lines]); From 0966c9fb21d00fa6b3ed9019de2c4c471fc6ec30 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Tue, 10 Dec 2024 12:08:50 -0800 Subject: [PATCH 05/12] fix(ui): feedback header when empty (#3191) --- .../Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx index ef6bcbd69ff..0b3c9603fef 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx @@ -96,7 +96,7 @@ export const FeedbackSidebar = ({
Feedback
-
+
{humanAnnotationSpecs.length > 0 ? ( <>
From 35a78d6e83006ea4462d48485651213a4afbd381 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Tue, 10 Dec 2024 16:20:17 -0800 Subject: [PATCH 06/12] chore(ui): make eval section header draggable (#2637) --- .../CompareEvaluationsPage.tsx | 14 +-- .../compareEvaluationsContext.tsx | 25 ++-- .../pages/CompareEvaluationsPage/ecpState.ts | 63 +++++++--- .../pages/CompareEvaluationsPage/ecpTypes.ts | 1 + .../ComparisonDefinitionSection.tsx | 110 ++++++++++++------ .../EvaluationDefinition.tsx | 83 +------------ .../DraggableSection/DraggableItem.tsx | 103 ++++++++++++++++ .../DraggableSection/DraggableSection.tsx | 34 ++++++ .../ExampleCompareSection.tsx | 2 +- .../ExampleFilterSection.tsx | 4 +- .../ScorecardSection/ScorecardSection.tsx | 3 +- 11 files changed, 280 insertions(+), 162 deletions(-) create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx index b302b0262ae..478c4887546 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx @@ -77,8 +77,6 @@ export const CompareEvaluationsPage: React.FC< export const CompareEvaluationsPageContent: React.FC< CompareEvaluationsPageProps > = props => { - const [baselineEvaluationCallId, setBaselineEvaluationCallId] = - React.useState(null); const [comparisonDimensions, setComparisonDimensions] = React.useState(null); @@ -104,14 +102,6 @@ export const CompareEvaluationsPageContent: React.FC< [comparisonDimensions] ); - React.useEffect(() => { - // Only update the baseline if we are switching evaluations, if there - // is more than 1, we are in the compare view and baseline is auto set - if (props.evaluationCallIds.length === 1) { - setBaselineEvaluationCallId(props.evaluationCallIds[0]); - } - }, [props.evaluationCallIds]); - if (props.evaluationCallIds.length === 0) { return
No evaluations to compare
; } @@ -120,13 +110,11 @@ export const CompareEvaluationsPageContent: React.FC< diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx index b65658a890a..638565e8ad6 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx @@ -10,9 +10,6 @@ import {ComparisonDimensionsType} from './ecpState'; const CompareEvaluationsContext = React.createContext<{ state: EvaluationComparisonState; - setBaselineEvaluationCallId: React.Dispatch< - React.SetStateAction - >; setComparisonDimensions: React.Dispatch< React.SetStateAction >; @@ -20,6 +17,7 @@ const CompareEvaluationsContext = React.createContext<{ setSelectedMetrics: (newModel: Record) => void; addEvaluationCall: (newCallId: string) => void; removeEvaluationCall: (callId: string) => void; + setEvaluationCallOrder: (newCallIdOrder: string[]) => void; } | null>(null); export const useCompareEvaluationsState = () => { @@ -33,34 +31,26 @@ export const useCompareEvaluationsState = () => { export const CompareEvaluationsProvider: React.FC<{ entity: string; project: string; + initialEvaluationCallIds: string[]; selectedMetrics: Record | null; setSelectedMetrics: (newModel: Record) => void; - initialEvaluationCallIds: string[]; onEvaluationCallIdsUpdate: (newEvaluationCallIds: string[]) => void; - setBaselineEvaluationCallId: React.Dispatch< - React.SetStateAction - >; setComparisonDimensions: React.Dispatch< React.SetStateAction >; setSelectedInputDigest: React.Dispatch>; - baselineEvaluationCallId?: string; comparisonDimensions?: ComparisonDimensionsType; selectedInputDigest?: string; }> = ({ entity, project, + initialEvaluationCallIds, selectedMetrics, setSelectedMetrics, - - initialEvaluationCallIds, onEvaluationCallIdsUpdate, - setBaselineEvaluationCallId, setComparisonDimensions, - setSelectedInputDigest, - baselineEvaluationCallId, comparisonDimensions, selectedInputDigest, children, @@ -77,7 +67,6 @@ export const CompareEvaluationsProvider: React.FC<{ entity, project, evaluationCallIds, - baselineEvaluationCallId, comparisonDimensions, selectedInputDigest, selectedMetrics ?? undefined @@ -89,7 +78,6 @@ export const CompareEvaluationsProvider: React.FC<{ } return { state: initialState.result, - setBaselineEvaluationCallId, setComparisonDimensions, setSelectedInputDigest, setSelectedMetrics, @@ -105,14 +93,17 @@ export const CompareEvaluationsProvider: React.FC<{ setEvaluationCallIds(newEvaluationCallIds); onEvaluationCallIdsUpdate(newEvaluationCallIds); }, + setEvaluationCallOrder: (newCallIdOrder: string[]) => { + setEvaluationCallIds(newCallIdOrder); + onEvaluationCallIdsUpdate(newCallIdOrder); + }, }; }, [ initialState.loading, initialState.result, + setEvaluationCallIds, evaluationCallIds, onEvaluationCallIdsUpdate, - setEvaluationCallIds, - setBaselineEvaluationCallId, setComparisonDimensions, setSelectedInputDigest, setSelectedMetrics, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts index 35f95dbf14f..e5c1b03d60a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts @@ -18,14 +18,14 @@ import {getMetricIds} from './ecpUtil'; export type EvaluationComparisonState = { // The normalized data for the evaluations data: EvaluationComparisonData; - // The evaluation call id of the baseline model - baselineEvaluationCallId: string; // The dimensions to compare & filter results comparisonDimensions?: ComparisonDimensionsType; // The current digest which is in view selectedInputDigest?: string; // The selected metrics to display selectedMetrics?: Record; + // Ordered call Ids + evaluationCallIdsOrdered: string[]; }; export type ComparisonDimensionsType = Array<{ @@ -43,12 +43,14 @@ export const useEvaluationComparisonState = ( entity: string, project: string, evaluationCallIds: string[], - baselineEvaluationCallId?: string, comparisonDimensions?: ComparisonDimensionsType, selectedInputDigest?: string, selectedMetrics?: Record ): Loadable => { - const data = useEvaluationComparisonData(entity, project, evaluationCallIds); + const orderedCallIds = useMemo(() => { + return getCallIdsOrderedForQuery(evaluationCallIds); + }, [evaluationCallIds]); + const data = useEvaluationComparisonData(entity, project, orderedCallIds); const value = useMemo(() => { if (data.result == null || data.loading) { @@ -92,42 +94,45 @@ export const useEvaluationComparisonState = ( loading: false, result: { data: data.result, - baselineEvaluationCallId: - baselineEvaluationCallId ?? evaluationCallIds[0], comparisonDimensions: newComparisonDimensions, selectedInputDigest, selectedMetrics, + evaluationCallIdsOrdered: evaluationCallIds, }, }; }, [ data.result, data.loading, - baselineEvaluationCallId, - evaluationCallIds, comparisonDimensions, selectedInputDigest, selectedMetrics, + evaluationCallIds, ]); return value; }; +export const getOrderedCallIds = (state: EvaluationComparisonState) => { + return Array.from(state.evaluationCallIdsOrdered); +}; + +export const getBaselineCallId = (state: EvaluationComparisonState) => { + return getOrderedCallIds(state)[0]; +}; + /** - * Should use this over keys of `state.data.evaluationCalls` because it ensures the baseline - * evaluation call is first. + * Sort call IDs to ensure consistent order for memoized query params */ -export const getOrderedCallIds = (state: EvaluationComparisonState) => { - const initial = Object.keys(state.data.evaluationCalls); - moveItemToFront(initial, state.baselineEvaluationCallId); - return initial; +const getCallIdsOrderedForQuery = (callIds: string[]) => { + return Array.from(callIds).sort(); }; /** * Should use this over keys of `state.data.models` because it ensures the baseline model is first. */ export const getOrderedModelRefs = (state: EvaluationComparisonState) => { - const baselineRef = - state.data.evaluationCalls[state.baselineEvaluationCallId].modelRef; + const baselineCallId = getBaselineCallId(state); + const baselineRef = state.data.evaluationCalls[baselineCallId].modelRef; const refs = Object.keys(state.data.models); // Make sure the baseline model is first moveItemToFront(refs, baselineRef); @@ -145,3 +150,29 @@ const moveItemToFront = (arr: T[], item: T): T[] => { } return arr; }; + +export const getOrderedEvalsWithNewBaseline = ( + evaluationCallIds: string[], + newBaselineCallId: string +) => { + return moveItemToFront(evaluationCallIds, newBaselineCallId); +}; + +export const swapEvaluationCalls = ( + evaluationCallIds: string[], + ndx1: number, + ndx2: number +): string[] => { + return swapArrayItems(evaluationCallIds, ndx1, ndx2); +}; + +const swapArrayItems = (arr: T[], ndx1: number, ndx2: number): T[] => { + if (ndx1 < 0 || ndx2 < 0 || ndx1 >= arr.length || ndx2 >= arr.length) { + throw new Error('Index out of bounds'); + } + const newArr = [...arr]; + const from = newArr[ndx1]; + newArr[ndx1] = newArr[ndx2]; + newArr[ndx2] = from; + return newArr; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts index 7454e0707b4..b4642fae240 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts @@ -18,6 +18,7 @@ export type EvaluationComparisonData = { }; // EvaluationCalls are the specific calls of an evaluation. + // The visual order of the evaluation calls is determined by the order of the keys. evaluationCalls: { [callId: string]: EvaluationCall; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx index 66ce56c4cb0..2704a66cbea 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx @@ -12,44 +12,81 @@ import { import {useCallsForQuery} from '../../../CallsPage/callsTableQuery'; import {useEvaluationsFilter} from '../../../CallsPage/evaluationsFilter'; import {Id} from '../../../common/Id'; +import {opNiceName} from '../../../common/Links'; import {useWFHooks} from '../../../wfReactInterface/context'; import { CallSchema, ObjectVersionKey, } from '../../../wfReactInterface/wfDataModelHooksInterface'; import {useCompareEvaluationsState} from '../../compareEvaluationsContext'; -import {STANDARD_PADDING} from '../../ecpConstants'; -import {getOrderedCallIds} from '../../ecpState'; -import {EvaluationComparisonState} from '../../ecpState'; +import { + EvaluationComparisonState, + getOrderedCallIds, + getOrderedEvalsWithNewBaseline, + swapEvaluationCalls, +} from '../../ecpState'; import {HorizontalBox} from '../../Layout'; -import {EvaluationDefinition, VerticalBar} from './EvaluationDefinition'; +import {ItemDef} from '../DraggableSection/DraggableItem'; +import {DraggableSection} from '../DraggableSection/DraggableSection'; +import {VerticalBar} from './EvaluationDefinition'; export const ComparisonDefinitionSection: React.FC<{ state: EvaluationComparisonState; }> = props => { - const evalCallIds = useMemo( - () => getOrderedCallIds(props.state), - [props.state] - ); + const {setEvaluationCallOrder, removeEvaluationCall} = + useCompareEvaluationsState(); + + const callIds = useMemo(() => { + return getOrderedCallIds(props.state); + }, [props.state]); + + const items: ItemDef[] = useMemo(() => { + return callIds.map(callId => ({ + key: 'evaluations', + value: callId, + label: props.state.data.evaluationCalls[callId]?.name ?? callId, + })); + }, [callIds, props.state.data.evaluationCalls]); + + const onSetBaseline = (value: string | null) => { + if (!value) { + return; + } + const newSortOrder = getOrderedEvalsWithNewBaseline(callIds, value); + setEvaluationCallOrder(newSortOrder); + }; + const onRemoveItem = (value: string) => removeEvaluationCall(value); + const onSortEnd = ({ + oldIndex, + newIndex, + }: { + oldIndex: number; + newIndex: number; + }) => { + const newSortOrder = swapEvaluationCalls(callIds, oldIndex, newIndex); + setEvaluationCallOrder(newSortOrder); + }; return ( - - {evalCallIds.map((key, ndx) => { - return ( - - - - ); - })} - - + +
+ + + + + + +
+
); }; @@ -81,7 +118,7 @@ const ModelRefLabel: React.FC<{modelRef: string}> = props => { const objectVersion = useObjectVersion(objVersionKey); return ( - {objectVersion.result?.objectId}:{objectVersion.result?.versionIndex} + {objectVersion.result?.objectId}:v{objectVersion.result?.versionIndex} ); }; @@ -119,10 +156,9 @@ const AddEvaluationButton: React.FC<{ const evalsNotComparing = useMemo(() => { return calls.result.filter( - call => - !Object.keys(props.state.data.evaluationCalls).includes(call.callId) + call => !getOrderedCallIds(props.state).includes(call.callId) ); - }, [calls.result, props.state.data.evaluationCalls]); + }, [calls.result, props.state]); const [menuOptions, setMenuOptions] = useState(evalsNotComparing); @@ -222,12 +258,18 @@ const AddEvaluationButton: React.FC<{ variant="ghost" size="small" className="pb-8 pt-8 font-['Source_Sans_Pro'] text-base font-normal text-moon-800" - onClick={() => { - addEvaluationCall(call.callId); - }}> + onClick={() => addEvaluationCall(call.callId)}> <> - {call.displayName ?? call.spanName} - + + {call.displayName ?? opNiceName(call.spanName)} + + + + diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx index adc80d044f6..5dcf835e378 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx @@ -1,8 +1,5 @@ import {Box} from '@material-ui/core'; import {Circle} from '@mui/icons-material'; -import {PopupDropdown} from '@wandb/weave/common/components/PopupDropdown'; -import {Button} from '@wandb/weave/components/Button'; -import {Pill} from '@wandb/weave/components/Tag'; import React, {useMemo} from 'react'; import { @@ -17,87 +14,17 @@ import {SmallRef} from '../../../../../Browse2/SmallRef'; import {CallLink, ObjectVersionLink} from '../../../common/Links'; import {useWFHooks} from '../../../wfReactInterface/context'; import {ObjectVersionKey} from '../../../wfReactInterface/wfDataModelHooksInterface'; -import {useCompareEvaluationsState} from '../../compareEvaluationsContext'; -import { - BOX_RADIUS, - CIRCLE_SIZE, - EVAL_DEF_HEIGHT, - STANDARD_BORDER, -} from '../../ecpConstants'; +import {CIRCLE_SIZE} from '../../ecpConstants'; import {EvaluationComparisonState} from '../../ecpState'; -import {HorizontalBox} from '../../Layout'; - -export const EvaluationDefinition: React.FC<{ - state: EvaluationComparisonState; - callId: string; -}> = props => { - const {removeEvaluationCall, setBaselineEvaluationCallId} = - useCompareEvaluationsState(); - - const menuOptions = useMemo(() => { - return [ - { - key: 'add-to-baseline', - content: 'Set as baseline', - onClick: () => { - setBaselineEvaluationCallId(props.callId); - }, - disabled: props.callId === props.state.baselineEvaluationCallId, - }, - { - key: 'remove', - content: 'Remove', - onClick: () => { - removeEvaluationCall(props.callId); - }, - disabled: Object.keys(props.state.data.evaluationCalls).length === 1, - }, - ]; - }, [ - props.callId, - props.state.baselineEvaluationCallId, - props.state.data.evaluationCalls, - removeEvaluationCall, - setBaselineEvaluationCallId, - ]); - - return ( - - - {props.callId === props.state.baselineEvaluationCallId && ( - - )} -
- - } - /> -
-
- ); -}; export const EvaluationCallLink: React.FC<{ callId: string; state: EvaluationComparisonState; }> = props => { - const evaluationCall = props.state.data.evaluationCalls[props.callId]; + const evaluationCall = props.state.data.evaluationCalls?.[props.callId]; + if (!evaluationCall) { + return null; + } const {entity, project} = props.state.data; return ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx new file mode 100644 index 00000000000..1510c502b99 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx @@ -0,0 +1,103 @@ +import {Button} from '@wandb/weave/components/Button'; +import * as DropdownMenu from '@wandb/weave/components/DropdownMenu'; +import {Icon} from '@wandb/weave/components/Icon'; +import {Pill} from '@wandb/weave/components/Tag/Pill'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import classNames from 'classnames'; +import React, {useState} from 'react'; +import {SortableElement, SortableHandle} from 'react-sortable-hoc'; + +import {EvaluationComparisonState} from '../../ecpState'; +import {EvaluationCallLink} from '../ComparisonDefinitionSection/EvaluationDefinition'; + +export type ItemDef = { + key: string; + value: string; + label?: string; +}; + +type DraggableItemProps = { + state: EvaluationComparisonState; + item: ItemDef; + numItems: number; + idx: number; + onRemoveItem: (value: string) => void; + onSetBaseline: (value: string | null) => void; +}; + +export const DraggableItem = SortableElement( + ({ + state, + item, + numItems, + idx, + onRemoveItem, + onSetBaseline, + }: DraggableItemProps) => { + const isDeletable = numItems > 1; + const isBaseline = idx === 0; + const [isOpen, setIsOpen] = useState(false); + + const onMakeBaselinePropagated = (e: React.MouseEvent) => { + e.stopPropagation(); + onSetBaseline(item.value); + }; + + const onRemoveItemPropagated = (e: React.MouseEvent) => { + e.stopPropagation(); + onRemoveItem(item.value); + }; + + return ( + +
+ +
+ + {isBaseline && ( + + )} +
+ + +
+
+ ); + } +); + +const DragHandle = SortableHandle(() => ( +
+ +
+)); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx new file mode 100644 index 00000000000..23a03ceb5b3 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx @@ -0,0 +1,34 @@ +import React from 'react'; +import {SortableContainer} from 'react-sortable-hoc'; + +import {EvaluationComparisonState} from '../../ecpState'; +import {DraggableItem} from './DraggableItem'; +import {ItemDef} from './DraggableItem'; + +type DraggableSectionProps = { + state: EvaluationComparisonState; + items: ItemDef[]; + onSetBaseline: (value: string | null) => void; + onRemoveItem: (value: string) => void; +}; + +export const DraggableSection = SortableContainer( + ({state, items, onSetBaseline, onRemoveItem}: DraggableSectionProps) => { + return ( +
+ {items.map((item, index) => ( + + ))} +
+ ); + } +); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx index 6041492b5c5..398f65ecd45 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx @@ -149,7 +149,7 @@ const stickySidebarHeaderMixin: React.CSSProperties = { /** * This component will occupy the entire space provided by the parent container. - * It is intended to be used in teh CompareEvaluations page, as it depends on + * It is intended to be used in the CompareEvaluations page, as it depends on * the EvaluationComparisonState. However, in principle, it is a general purpose * model-output comparison tool. It allows the user to view inputs, then compare * model outputs and evaluation metrics across multiple trials. diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx index b2c8773a7d1..1146c8ea960 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx @@ -13,7 +13,7 @@ import { } from '../../compositeMetricsUtil'; import {PLOT_HEIGHT, STANDARD_PADDING} from '../../ecpConstants'; import {MAX_PLOT_DOT_SIZE, MIN_PLOT_DOT_SIZE} from '../../ecpConstants'; -import {EvaluationComparisonState} from '../../ecpState'; +import {EvaluationComparisonState, getBaselineCallId} from '../../ecpState'; import {metricDefinitionId} from '../../ecpUtil'; import { flattenedDimensionPath, @@ -103,7 +103,7 @@ const SingleDimensionFilter: React.FC<{ }, [props.state.data]); const {setComparisonDimensions} = useCompareEvaluationsState(); - const baselineCallId = props.state.baselineEvaluationCallId; + const baselineCallId = getBaselineCallId(props.state); const compareCallId = Object.keys(props.state.data.evaluationCalls).find( callId => callId !== baselineCallId )!; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx index 4b319150fa8..303f640f47e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx @@ -32,6 +32,7 @@ import { } from '../../ecpConstants'; import { EvaluationComparisonState, + getBaselineCallId, getOrderedCallIds, getOrderedModelRefs, } from '../../ecpState'; @@ -414,7 +415,7 @@ export const ScorecardSection: React.FC<{ {evalCallIds.map((evalCallId, mNdx) => { const baseline = resolveSummaryMetricResult( - props.state.baselineEvaluationCallId, + getBaselineCallId(props.state), groupName, metricKey, compositeSummaryMetrics, From f967984c8bf49bf49e1e5a500b338381f2481e4e Mon Sep 17 00:00:00 2001 From: Erica Diaz <156136421+ericakdiaz@users.noreply.github.com> Date: Tue, 10 Dec 2024 16:23:39 -0800 Subject: [PATCH 07/12] chore(weave): Update ToggleButtonGroup to allow disabling individual options (#3194) --- weave-js/src/components/ToggleButtonGroup.tsx | 67 +++++++++++-------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/weave-js/src/components/ToggleButtonGroup.tsx b/weave-js/src/components/ToggleButtonGroup.tsx index 65b2c538975..eca93e95657 100644 --- a/weave-js/src/components/ToggleButtonGroup.tsx +++ b/weave-js/src/components/ToggleButtonGroup.tsx @@ -9,6 +9,7 @@ import {Tailwind} from './Tailwind'; export type ToggleOption = { value: string; icon?: IconName; + isDisabled?: boolean; }; export type ToggleButtonGroupProps = { @@ -37,7 +38,10 @@ export const ToggleButtonGroup = React.forwardRef< } const handleValueChange = (newValue: string) => { - if (newValue !== value) { + if ( + newValue !== value && + options.find(option => option.value === newValue)?.isDisabled !== true + ) { onValueChange(newValue); } }; @@ -49,34 +53,39 @@ export const ToggleButtonGroup = React.forwardRef< onValueChange={handleValueChange} className="flex gap-px" ref={ref}> - {options.map(({value: optionValue, icon}) => ( - - - - ))} + {options.map( + ({value: optionValue, icon, isDisabled: optionIsDisabled}) => ( + + + + ) + )} ); From 4e1e2c9a6da47ceff2fbd113053d959b43b4ce6f Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 10 Dec 2024 16:30:13 -0800 Subject: [PATCH 08/12] chore(weave): Fix predict detection heuristic for evals (#3196) * init * lint --- .../tsDataModelHooksEvaluationComparison.ts | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts index 8418f2976ad..847421464e3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts @@ -69,7 +69,6 @@ * across different datasets. */ -import _ from 'lodash'; import {sum} from 'lodash'; import {useEffect, useMemo, useRef, useState} from 'react'; @@ -479,17 +478,8 @@ const fetchEvaluationComparisonData = async ( const maybeDigest = parts[1]; if (maybeDigest != null && !maybeDigest.includes('/')) { const rowDigest = maybeDigest; - const possiblePredictNames = [ - 'predict', - 'infer', - 'forward', - 'invoke', - ]; const isProbablyPredictCall = - (_.some(possiblePredictNames, name => - traceCall.op_name.includes(`.${name}:`) - ) && - modelRefs.includes(traceCall.inputs.self)) || + modelRefs.includes(traceCall.inputs.self) || modelRefs.includes(traceCall.op_name); const isProbablyScoreCall = scorerRefs.has(traceCall.op_name); From bf39669bf3ebba977e9b6cd9224c9abce22a3a62 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Tue, 10 Dec 2024 16:36:47 -0800 Subject: [PATCH 09/12] fix(ui): Evaluations -- normalize radar data, show real values for bar charts (#3199) --- .../SummaryPlotsSection.tsx | 94 ++++++++++--------- 1 file changed, 50 insertions(+), 44 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx index c23ffcce04d..02c456df850 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx @@ -29,7 +29,7 @@ export const SummaryPlots: React.FC<{ state: EvaluationComparisonState; setSelectedMetrics: (newModel: Record) => void; }> = ({state, setSelectedMetrics}) => { - const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics(state); + const {radarData, allMetricNames} = usePlotDataFromMetrics(state); const {selectedMetrics} = state; // Initialize selectedMetrics if null @@ -237,11 +237,10 @@ const useFilteredData = ( return data; }, [radarData, selectedMetrics]); -function getMetricValuesFromRadarData(radarData: RadarPlotData): { +function getMetricValuesMap(radarData: RadarPlotData): { [metric: string]: number[]; } { const metricValues: {[metric: string]: number[]} = {}; - // Gather all values for each metric Object.values(radarData).forEach(callData => { Object.entries(callData.metrics).forEach(([metric, value]) => { if (!metricValues[metric]) { @@ -253,37 +252,54 @@ function getMetricValuesFromRadarData(radarData: RadarPlotData): { return metricValues; } -function getMetricMinsFromRadarData(radarData: RadarPlotData): { - [metric: string]: number; +function normalizeMetricValues(values: number[]): { + normalizedValues: number[]; + normalizer: number; } { - const metricValues = getMetricValuesFromRadarData(radarData); - const metricMins: {[metric: string]: number} = {}; - Object.entries(metricValues).forEach(([metric, values]) => { - metricMins[metric] = Math.min(...values); - }); - return metricMins; + const min = Math.min(...values); + const max = Math.max(...values); + + if (min === max) { + return { + normalizedValues: values.map(() => 0.5), + normalizer: 1, + }; + } + + // Handle negative values by shifting + const shiftedValues = min < 0 ? values.map(v => v - min) : values; + const maxValue = min < 0 ? max - min : max; + + const maxPower = Math.ceil(Math.log2(maxValue)); + const normalizer = Math.pow(2, maxPower); + + return { + normalizedValues: shiftedValues.map(v => v / normalizer), + normalizer, + }; } -function normalizeDataForRadarPlot(radarData: RadarPlotData): RadarPlotData { - const metricMins = getMetricMinsFromRadarData(radarData); +function normalizeDataForRadarPlot( + radarDataOriginal: RadarPlotData +): RadarPlotData { + const radarData = Object.fromEntries( + Object.entries(radarDataOriginal).map(([callId, callData]) => [ + callId, + {...callData, metrics: {...callData.metrics}}, + ]) + ); - const normalizedData: RadarPlotData = {}; - Object.entries(radarData).forEach(([callId, callData]) => { - normalizedData[callId] = { - name: callData.name, - color: callData.color, - metrics: {}, - }; + const metricValues = getMetricValuesMap(radarData); - Object.entries(callData.metrics).forEach(([metric, value]) => { - const min = metricMins[metric]; - // Only shift values if there are negative values - const normalizedValue = min < 0 ? value - min : value; - normalizedData[callId].metrics[metric] = normalizedValue; + // Normalize each metric independently + Object.entries(metricValues).forEach(([metric, values]) => { + const {normalizedValues} = normalizeMetricValues(values); + Object.values(radarData).forEach((callData, index) => { + callData.metrics[metric] = normalizedValues[index]; }); }); - return normalizedData; + return radarData; } const useBarPlotData = (filteredData: RadarPlotData) => @@ -317,7 +333,9 @@ const useBarPlotData = (filteredData: RadarPlotData) => type: 'bar', y: metricBin.values, x: metricBin.callIds, - text: metricBin.values.map(value => value.toFixed(3)), + text: metricBin.values.map(value => + Number.isInteger(value) ? value.toString() : value.toFixed(3) + ), textposition: 'outside', textfont: {size: 14, color: 'black'}, name: metric, @@ -408,16 +426,7 @@ const usePaginatedPlots = ( return {plotsToShow, totalPlots, startIndex, endIndex, totalPages}; }; -function normalizeValues(values: Array): number[] { - // find the max value - // find the power of 2 that is greater than the max value - // divide all values by that power of 2 - const maxVal = Math.max(...(values.filter(v => v !== undefined) as number[])); - const maxPower = Math.ceil(Math.log2(maxVal)); - return values.map(val => (val ? val / 2 ** maxPower : 0)); -} - -const useNormalizedPlotDataFromMetrics = ( +const usePlotDataFromMetrics = ( state: EvaluationComparisonState ): {radarData: RadarPlotData; allMetricNames: Set} => { const compositeMetrics = useMemo(() => { @@ -428,7 +437,7 @@ const useNormalizedPlotDataFromMetrics = ( }, [state]); return useMemo(() => { - const normalizedMetrics = Object.values(compositeMetrics) + const metrics = Object.values(compositeMetrics) .map(scoreGroup => Object.values(scoreGroup.metrics)) .flat() .map(metric => { @@ -449,11 +458,8 @@ const useNormalizedPlotDataFromMetrics = ( return val; } }); - const normalizedValues = normalizeValues(values); const evalScores: {[evalCallId: string]: number | undefined} = - Object.fromEntries( - callIds.map((key, i) => [key, normalizedValues[i]]) - ); + Object.fromEntries(callIds.map((key, i) => [key, values[i]])); const metricLabel = flattenedDimensionPath( Object.values(metric.scorerRefs)[0].metric @@ -472,7 +478,7 @@ const useNormalizedPlotDataFromMetrics = ( name: evalCall.name, color: evalCall.color, metrics: Object.fromEntries( - normalizedMetrics.map(metric => { + metrics.map(metric => { return [ metric.metricLabel, metric.evalScores[evalCall.callId] ?? 0, @@ -483,7 +489,7 @@ const useNormalizedPlotDataFromMetrics = ( ]; }) ); - const allMetricNames = new Set(normalizedMetrics.map(m => m.metricLabel)); + const allMetricNames = new Set(metrics.map(m => m.metricLabel)); return {radarData, allMetricNames}; }, [callIds, compositeMetrics, state.data.evaluationCalls]); }; From af81503aa330c0659a8c65b316b1071782a543e6 Mon Sep 17 00:00:00 2001 From: Jeff Raubitschek Date: Wed, 11 Dec 2024 12:34:22 -0800 Subject: [PATCH 10/12] chore(weave): update db to use replicated tables (#3148) --- .../test_clickhouse_trace_server_migrator.py | 230 ++++++++++++++++++ .../clickhouse_trace_server_migrator.py | 132 ++++++++-- 2 files changed, 336 insertions(+), 26 deletions(-) create mode 100644 tests/trace_server/test_clickhouse_trace_server_migrator.py diff --git a/tests/trace_server/test_clickhouse_trace_server_migrator.py b/tests/trace_server/test_clickhouse_trace_server_migrator.py new file mode 100644 index 00000000000..3a6a92f2479 --- /dev/null +++ b/tests/trace_server/test_clickhouse_trace_server_migrator.py @@ -0,0 +1,230 @@ +import types +from unittest.mock import Mock, call, patch + +import pytest + +from weave.trace_server import clickhouse_trace_server_migrator as trace_server_migrator +from weave.trace_server.clickhouse_trace_server_migrator import MigrationError + + +@pytest.fixture +def mock_costs(): + with patch( + "weave.trace_server.costs.insert_costs.should_insert_costs", return_value=False + ) as mock_should_insert: + with patch( + "weave.trace_server.costs.insert_costs.get_current_costs", return_value=[] + ) as mock_get_costs: + yield + + +@pytest.fixture +def migrator(): + ch_client = Mock() + migrator = trace_server_migrator.ClickHouseTraceServerMigrator(ch_client) + migrator._get_migration_status = Mock() + migrator._get_migrations = Mock() + migrator._determine_migrations_to_apply = Mock() + migrator._update_migration_status = Mock() + ch_client.command.reset_mock() + return migrator + + +def test_apply_migrations_with_target_version(mock_costs, migrator, tmp_path): + # Setup + migrator._get_migration_status.return_value = { + "curr_version": 1, + "partially_applied_version": None, + } + migrator._get_migrations.return_value = { + "1": {"up": "1.up.sql", "down": "1.down.sql"}, + "2": {"up": "2.up.sql", "down": "2.down.sql"}, + } + migrator._determine_migrations_to_apply.return_value = [(2, "2.up.sql")] + + # Create a temporary migration file + migration_dir = tmp_path / "migrations" + migration_dir.mkdir() + migration_file = migration_dir / "2.up.sql" + migration_file.write_text( + "CREATE TABLE test1 (id Int32);\nCREATE TABLE test2 (id Int32);" + ) + + # Mock the migration directory path + with patch("os.path.dirname") as mock_dirname: + mock_dirname.return_value = str(tmp_path) + + # Execute + migrator.apply_migrations("test_db", target_version=2) + + # Verify + migrator._get_migration_status.assert_called_once_with("test_db") + migrator._get_migrations.assert_called_once() + migrator._determine_migrations_to_apply.assert_called_once_with( + 1, migrator._get_migrations.return_value, 2 + ) + + # Verify migration execution + assert migrator._update_migration_status.call_count == 2 + migrator._update_migration_status.assert_has_calls( + [call("test_db", 2, is_start=True), call("test_db", 2, is_start=False)] + ) + + # Verify the actual SQL commands were executed + ch_client = migrator.ch_client + assert ch_client.command.call_count == 2 + ch_client.command.assert_has_calls( + [call("CREATE TABLE test1 (id Int32)"), call("CREATE TABLE test2 (id Int32)")] + ) + + +def test_execute_migration_command(mock_costs, migrator): + # Setup + ch_client = migrator.ch_client + ch_client.database = "original_db" + + # Execute + migrator._execute_migration_command("test_db", "CREATE TABLE test (id Int32)") + + # Verify + assert ch_client.database == "original_db" # Should restore original database + ch_client.command.assert_called_once_with("CREATE TABLE test (id Int32)") + + +def test_migration_replicated(mock_costs, migrator): + ch_client = migrator.ch_client + orig = "CREATE TABLE test (id String, project_id String) ENGINE = MergeTree ORDER BY (project_id, id);" + migrator._execute_migration_command("test_db", orig) + ch_client.command.assert_called_once_with(orig) + + +def test_update_migration_status(mock_costs, migrator): + # Don't mock _update_migration_status for this test + migrator._update_migration_status = types.MethodType( + trace_server_migrator.ClickHouseTraceServerMigrator._update_migration_status, + migrator, + ) + + # Test start of migration + migrator._update_migration_status("test_db", 2, is_start=True) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE partially_applied_version = 2 WHERE db_name = 'test_db'" + ) + + # Test end of migration + migrator._update_migration_status("test_db", 2, is_start=False) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE curr_version = 2, partially_applied_version = NULL WHERE db_name = 'test_db'" + ) + + +def test_is_safe_identifier(mock_costs, migrator): + # Valid identifiers + assert migrator._is_safe_identifier("test_db") + assert migrator._is_safe_identifier("my_db123") + assert migrator._is_safe_identifier("db.table") + + # Invalid identifiers + assert not migrator._is_safe_identifier("test-db") + assert not migrator._is_safe_identifier("db;") + assert not migrator._is_safe_identifier("db'name") + assert not migrator._is_safe_identifier("db/*") + + +def test_create_db_sql_validation(mock_costs, migrator): + # Test invalid database name + with pytest.raises(MigrationError, match="Invalid database name"): + migrator._create_db_sql("test;db") + + # Test replicated mode with invalid values + migrator.replicated = True + migrator.replicated_cluster = "test;cluster" + with pytest.raises(MigrationError, match="Invalid cluster name"): + migrator._create_db_sql("test_db") + + migrator.replicated_cluster = "test_cluster" + migrator.replicated_path = "/clickhouse/bad;path/{db}" + with pytest.raises(MigrationError, match="Invalid replicated path"): + migrator._create_db_sql("test_db") + + +def test_create_db_sql_non_replicated(mock_costs, migrator): + # Test non-replicated mode + migrator.replicated = False + sql = migrator._create_db_sql("test_db") + assert sql.strip() == "CREATE DATABASE IF NOT EXISTS test_db" + + +def test_create_db_sql_replicated(mock_costs, migrator): + # Test replicated mode + migrator.replicated = True + migrator.replicated_path = "/clickhouse/tables/{db}" + migrator.replicated_cluster = "test_cluster" + + sql = migrator._create_db_sql("test_db") + expected = """ + CREATE DATABASE IF NOT EXISTS test_db ON CLUSTER test_cluster ENGINE=Replicated('/clickhouse/tables/test_db', '{shard}', '{replica}') + """.strip() + assert sql.strip() == expected + + +def test_format_replicated_sql_non_replicated(mock_costs, migrator): + # Test that SQL is unchanged when replicated=False + migrator.replicated = False + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql + + +def test_format_replicated_sql_replicated(mock_costs, migrator): + # Test that MergeTree engines are converted to Replicated variants + migrator.replicated = True + + test_cases = [ + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedSummingMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedReplacingMergeTree", + ), + # Test with extra whitespace + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + # Test with parameters + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree()", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree()", + ), + ] + + for input_sql, expected_sql in test_cases: + assert migrator._format_replicated_sql(input_sql) == expected_sql + + +def test_format_replicated_sql_non_mergetree(mock_costs, migrator): + # Test that non-MergeTree engines are left unchanged + migrator.replicated = True + + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = Memory", + "CREATE TABLE test (id Int32) ENGINE = Log", + "CREATE TABLE test (id Int32) ENGINE = TinyLog", + # This should not be changed as it's not a complete word match + "CREATE TABLE test (id Int32) ENGINE = MyMergeTreeCustom", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql diff --git a/weave/trace_server/clickhouse_trace_server_migrator.py b/weave/trace_server/clickhouse_trace_server_migrator.py index 30dffe89365..4336630bf50 100644 --- a/weave/trace_server/clickhouse_trace_server_migrator.py +++ b/weave/trace_server/clickhouse_trace_server_migrator.py @@ -1,6 +1,7 @@ # Clickhouse Trace Server Manager import logging import os +import re from typing import Optional from clickhouse_connect.driver.client import Client as CHClient @@ -9,6 +10,11 @@ logger = logging.getLogger(__name__) +# These settings are only used when `replicated` mode is enabled for +# self managed clickhouse instances. +DEFAULT_REPLICATED_PATH = "/clickhouse/tables/{db}" +DEFAULT_REPLICATED_CLUSTER = "weave_cluster" + class MigrationError(RuntimeError): """Raised when a migration error occurs.""" @@ -16,15 +22,77 @@ class MigrationError(RuntimeError): class ClickHouseTraceServerMigrator: ch_client: CHClient + replicated: bool + replicated_path: str + replicated_cluster: str def __init__( self, ch_client: CHClient, + replicated: Optional[bool] = None, + replicated_path: Optional[str] = None, + replicated_cluster: Optional[str] = None, ): super().__init__() self.ch_client = ch_client + self.replicated = False if replicated is None else replicated + self.replicated_path = ( + DEFAULT_REPLICATED_PATH if replicated_path is None else replicated_path + ) + self.replicated_cluster = ( + DEFAULT_REPLICATED_CLUSTER + if replicated_cluster is None + else replicated_cluster + ) self._initialize_migration_db() + def _is_safe_identifier(self, value: str) -> bool: + """Check if a string is safe to use as an identifier in SQL.""" + return bool(re.match(r"^[a-zA-Z0-9_\.]+$", value)) + + def _format_replicated_sql(self, sql_query: str) -> str: + """Format SQL query to use replicated engines if replicated mode is enabled.""" + if not self.replicated: + return sql_query + + # Match "ENGINE = MergeTree" followed by word boundary + pattern = r"ENGINE\s*=\s*(\w+)?MergeTree\b" + + def replace_engine(match: re.Match[str]) -> str: + engine_prefix = match.group(1) or "" + return f"ENGINE = Replicated{engine_prefix}MergeTree" + + return re.sub(pattern, replace_engine, sql_query, flags=re.IGNORECASE) + + def _create_db_sql(self, db_name: str) -> str: + """Geneate SQL database create string for normal and replicated databases.""" + if not self._is_safe_identifier(db_name): + raise MigrationError(f"Invalid database name: {db_name}") + + replicated_engine = "" + replicated_cluster = "" + if self.replicated: + if not self._is_safe_identifier(self.replicated_cluster): + raise MigrationError(f"Invalid cluster name: {self.replicated_cluster}") + + replicated_path = self.replicated_path.replace("{db}", db_name) + if not all( + self._is_safe_identifier(part) + for part in replicated_path.split("/") + if part + ): + raise MigrationError(f"Invalid replicated path: {replicated_path}") + + replicated_cluster = f" ON CLUSTER {self.replicated_cluster}" + replicated_engine = ( + f" ENGINE=Replicated('{replicated_path}', '{{shard}}', '{{replica}}')" + ) + + create_db_sql = f""" + CREATE DATABASE IF NOT EXISTS {db_name}{replicated_cluster}{replicated_engine} + """ + return create_db_sql + def apply_migrations( self, target_db: str, target_version: Optional[int] = None ) -> None: @@ -46,20 +114,15 @@ def apply_migrations( return logger.info(f"Migrations to apply: {migrations_to_apply}") if status["curr_version"] == 0: - self.ch_client.command(f"CREATE DATABASE IF NOT EXISTS {target_db}") + self.ch_client.command(self._create_db_sql(target_db)) for target_version, migration_file in migrations_to_apply: self._apply_migration(target_db, target_version, migration_file) if should_insert_costs(status["curr_version"], target_version): insert_costs(self.ch_client, target_db) def _initialize_migration_db(self) -> None: - self.ch_client.command( - """ - CREATE DATABASE IF NOT EXISTS db_management - """ - ) - self.ch_client.command( - """ + self.ch_client.command(self._create_db_sql("db_management")) + create_table_sql = """ CREATE TABLE IF NOT EXISTS db_management.migrations ( db_name String, @@ -69,7 +132,7 @@ def _initialize_migration_db(self) -> None: ENGINE = MergeTree() ORDER BY (db_name) """ - ) + self.ch_client.command(self._format_replicated_sql(create_table_sql)) def _get_migration_status(self, db_name: str) -> dict: column_names = ["db_name", "curr_version", "partially_applied_version"] @@ -184,31 +247,48 @@ def _determine_migrations_to_apply( return [] + def _execute_migration_command(self, target_db: str, command: str) -> None: + """Execute a single migration command in the context of the target database.""" + command = command.strip() + if len(command) == 0: + return + curr_db = self.ch_client.database + self.ch_client.database = target_db + self.ch_client.command(self._format_replicated_sql(command)) + self.ch_client.database = curr_db + + def _update_migration_status( + self, target_db: str, target_version: int, is_start: bool = True + ) -> None: + """Update the migration status in db_management.migrations table.""" + if is_start: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}'" + ) + else: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}'" + ) + def _apply_migration( self, target_db: str, target_version: int, migration_file: str ) -> None: logger.info(f"Applying migration {migration_file} to `{target_db}`") migration_dir = os.path.join(os.path.dirname(__file__), "migrations") migration_file_path = os.path.join(migration_dir, migration_file) + with open(migration_file_path) as f: migration_sql = f.read() - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}' - """ - ) + + # Mark migration as partially applied + self._update_migration_status(target_db, target_version, is_start=True) + + # Execute each command in the migration migration_sub_commands = migration_sql.split(";") for command in migration_sub_commands: - command = command.strip() - if len(command) == 0: - continue - curr_db = self.ch_client.database - self.ch_client.database = target_db - self.ch_client.command(command) - self.ch_client.database = curr_db - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}' - """ - ) + self._execute_migration_command(target_db, command) + + # Mark migration as fully applied + self._update_migration_status(target_db, target_version, is_start=False) + logger.info(f"Migration {migration_file} applied to `{target_db}`") From 4c6db183be9b356dfc6781670c5b8f5fb6fd7532 Mon Sep 17 00:00:00 2001 From: Josiah Lee Date: Wed, 11 Dec 2024 12:54:59 -0800 Subject: [PATCH 11/12] add choices drawer (#3203) --- .../Browse3/pages/ChatView/ChoiceView.tsx | 8 +- .../Browse3/pages/ChatView/ChoicesDrawer.tsx | 104 ++++++++++++++++++ .../Browse3/pages/ChatView/ChoicesView.tsx | 39 ++++--- .../pages/ChatView/ChoicesViewCarousel.tsx | 27 +++-- .../pages/ChatView/ChoicesViewLinear.tsx | 73 ------------ .../Home/Browse3/pages/ChatView/types.ts | 2 - 6 files changed, 148 insertions(+), 105 deletions(-) create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx delete mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewLinear.tsx diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx index c8143f9549e..d1a2c59d5d0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx @@ -6,15 +6,21 @@ import {Choice} from './types'; type ChoiceViewProps = { choice: Choice; isStructuredOutput?: boolean; + isNested?: boolean; }; -export const ChoiceView = ({choice, isStructuredOutput}: ChoiceViewProps) => { +export const ChoiceView = ({ + choice, + isStructuredOutput, + isNested, +}: ChoiceViewProps) => { const {message} = choice; return ( ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx new file mode 100644 index 00000000000..16d6897c27b --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx @@ -0,0 +1,104 @@ +import {Box, Drawer} from '@mui/material'; +import {MOON_200} from '@wandb/weave/common/css/color.styles'; +import {Tag} from '@wandb/weave/components/Tag'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import React from 'react'; + +import {Button} from '../../../../../Button'; +import {ChoiceView} from './ChoiceView'; +import {Choice} from './types'; + +type ChoicesDrawerProps = { + choices: Choice[]; + isStructuredOutput?: boolean; + isDrawerOpen: boolean; + setIsDrawerOpen: React.Dispatch>; + selectedChoiceIndex: number; + setSelectedChoiceIndex: (choiceIndex: number) => void; +}; + +export const ChoicesDrawer = ({ + choices, + isStructuredOutput, + isDrawerOpen, + setIsDrawerOpen, + selectedChoiceIndex, + setSelectedChoiceIndex, +}: ChoicesDrawerProps) => { + return ( + setIsDrawerOpen(false)} + title="Choices" + anchor="right" + sx={{ + '& .MuiDrawer-paper': {mt: '60px', width: '400px'}, + }}> + + + Responses + + + ) : ( + + )} +
+ +
+ ))} +
+ + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx index c22df7c63d7..5ddc7f12202 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx @@ -1,9 +1,9 @@ import React, {useState} from 'react'; +import {ChoicesDrawer} from './ChoicesDrawer'; import {ChoicesViewCarousel} from './ChoicesViewCarousel'; -import {ChoicesViewLinear} from './ChoicesViewLinear'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewProps = { choices: Choice[]; @@ -14,7 +14,12 @@ export const ChoicesView = ({ choices, isStructuredOutput, }: ChoicesViewProps) => { - const [mode, setMode] = useState('linear'); + const [isDrawerOpen, setIsDrawerOpen] = useState(false); + const [localSelectedChoiceIndex, setLocalSelectedChoiceIndex] = useState(0); + + const handleSetSelectedChoiceIndex = (choiceIndex: number) => { + setLocalSelectedChoiceIndex(choiceIndex); + }; if (choices.length === 0) { return null; @@ -26,20 +31,20 @@ export const ChoicesView = ({ } return ( <> - {mode === 'linear' && ( - - )} - {mode === 'carousel' && ( - - )} + + ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx index f4a52fc6801..a34932dea17 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx @@ -1,34 +1,37 @@ -import React, {useState} from 'react'; +import React from 'react'; import {Button} from '../../../../../Button'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewCarouselProps = { choices: Choice[]; isStructuredOutput?: boolean; - setMode: React.Dispatch>; + setIsDrawerOpen: React.Dispatch>; + selectedChoiceIndex: number; + setSelectedChoiceIndex: (choiceIndex: number) => void; }; export const ChoicesViewCarousel = ({ choices, isStructuredOutput, - setMode, + setIsDrawerOpen, + selectedChoiceIndex, + setSelectedChoiceIndex, }: ChoicesViewCarouselProps) => { - const [step, setStep] = useState(0); - const onNext = () => { - setStep((step + 1) % choices.length); + setSelectedChoiceIndex((selectedChoiceIndex + 1) % choices.length); }; const onBack = () => { - const newStep = step === 0 ? choices.length - 1 : step - 1; - setStep(newStep); + const newStep = + selectedChoiceIndex === 0 ? choices.length - 1 : selectedChoiceIndex - 1; + setSelectedChoiceIndex(newStep); }; return ( <>
@@ -37,7 +40,7 @@ export const ChoicesViewCarousel = ({ size="small" variant="quiet" icon="expand-uncollapse" - onClick={() => setMode('linear')} + onClick={() => setIsDrawerOpen(true)} tooltip="Switch to linear view" />
@@ -48,7 +51,7 @@ export const ChoicesViewCarousel = ({ size="small" onClick={onBack} /> - {step + 1} of {choices.length} + {selectedChoiceIndex + 1} of {choices.length}