Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(ui): basic human feedback rendering in calls table #2991

Merged
merged 29 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
caf307b
wip
gtarpenning Nov 14, 2024
3252c78
wip
gtarpenning Nov 14, 2024
aae7ac7
lint
gtarpenning Nov 14, 2024
6f96e8a
working
gtarpenning Nov 14, 2024
e69741c
simpler
gtarpenning Nov 14, 2024
a6f470c
fetch and make annotation columns hidden by default
gtarpenning Nov 14, 2024
599e81f
undodeletecomment
gtarpenning Nov 14, 2024
9ce4add
prettier
gtarpenning Nov 15, 2024
61d26a1
lint
gtarpenning Nov 15, 2024
bd172ab
Merge branch 'master' into griffin/feedback-column-query
gtarpenning Nov 15, 2024
68e622e
josiah comments
gtarpenning Nov 15, 2024
2bb823c
Merge branch 'master' into griffin/feedback-column-query
gtarpenning Nov 18, 2024
189088a
fix two bugs
gtarpenning Nov 18, 2024
78d8cab
use feedback.[wandb.annotation.xxx] format
gtarpenning Nov 18, 2024
6e950b6
filtering is still broken...
gtarpenning Nov 18, 2024
0b76344
review-comments-minus-include_feedback
gtarpenning Nov 18, 2024
1e411b1
include_feedback
gtarpenning Nov 18, 2024
a9d5025
use better feedback structure
gtarpenning Nov 18, 2024
f8892f1
move
gtarpenning Nov 18, 2024
33a9949
fixed, working ish on filter and sort
gtarpenning Nov 18, 2024
695f882
better
gtarpenning Nov 18, 2024
b3177fd
better
gtarpenning Nov 19, 2024
fbc9bfe
comment
gtarpenning Nov 19, 2024
11a7c2f
better
gtarpenning Nov 19, 2024
24ae81e
comment
gtarpenning Nov 19, 2024
c74c117
boom
gtarpenning Nov 19, 2024
f2656d3
Merge branch 'master' into griffin/feedback-column-query
gtarpenning Nov 19, 2024
73bbbe9
Merge branch 'master' into griffin/feedback-column-query
gtarpenning Nov 19, 2024
f706426
reactions -> feedback
gtarpenning Nov 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {CallSchema} from '../../wfReactInterface/wfDataModelHooksInterface';
import {DEFAULT_COST_DATA, isCostDataKey, isUsageDataKey} from './costTypes';

const COST_PARAM_PREFIX = 'summary.weave.costs.';
const USAGE_PARAM_PREFIX = 'summary.usage.';

export const getCostFromCellParams = (params: {[key: string]: any}) => {
const costData: {[key: string]: LLMCostSchema} = {};
Expand All @@ -29,8 +30,8 @@ export const getCostFromCellParams = (params: {[key: string]: any}) => {
export const getUsageFromCellParams = (params: {[key: string]: any}) => {
const usage: {[key: string]: LLMUsageSchema} = {};
for (const key in params) {
if (key.startsWith('summary.usage')) {
const usageKeys = key.replace('summary.usage.', '').split('.');
if (key.startsWith(USAGE_PARAM_PREFIX)) {
const usageKeys = key.replace(`${USAGE_PARAM_PREFIX}.`, '').split('.');
const usageKey = usageKeys.pop() || '';
if (isUsageDataKey(usageKey)) {
const model = usageKeys.join('.');
Expand Down Expand Up @@ -87,22 +88,45 @@ export const formatTokenCost = (cost: number): string => {
return `$${cost.toFixed(2)}`;
};

// TODO(Josiah): this is here because sometimes the cost query is not returning all the ids I believe for unfinished calls,
// to get this cost uptake out, this function can be removed, once that is fixed
export const addCostsToCallResults = (
callResults: CallSchema[],
costResults: CallSchema[]
) => {
const costDict = costResults.reduce((acc, cost) => {
if (cost.callId) {
acc[cost.callId] = cost;
): CallSchema[] => {
const costDict = costResults.reduce((acc, costResult) => {
if (costResult.callId) {
gtarpenning marked this conversation as resolved.
Show resolved Hide resolved
acc[costResult.callId] = {
summary: {
weave: {
costs: costResult.traceCall?.summary?.weave?.costs,
},
usage: costResult.traceCall?.summary?.usage,
},
};
}
return acc;
}, {} as Record<string, CallSchema>);
}, {} as Record<string, any>);

return callResults.map(call => {
if (call.callId && costDict[call.callId]) {
return {...call, ...costDict[call.callId]};
if (!call.traceCall) {
return call;
}
// Merge cost fields into existing call data
const merged = {
...call,
traceCall: {
...call.traceCall,
summary: {
...call.traceCall?.summary,
weave: {
...call.traceCall?.summary?.weave,
costs: costDict[call.callId].summary.weave.costs,
},
},
usage: costDict[call.callId].summary.usage,
},
};
return merged;
gtarpenning marked this conversation as resolved.
Show resolved Hide resolved
}
return call;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,13 @@ import {
useControllableState,
useURLSearchParamsDict,
} from '../util';
import {useBaseObjectInstances} from '../wfReactInterface/baseObjectClassQuery';
import {useWFHooks} from '../wfReactInterface/context';
import {TraceCallSchema} from '../wfReactInterface/traceServerClientTypes';
import {traceCallToUICallSchema} from '../wfReactInterface/tsDataModelHooks';
import {
projectIdFromParts,
traceCallToUICallSchema,
} from '../wfReactInterface/tsDataModelHooks';
import {EXPANDED_REF_REF_KEY} from '../wfReactInterface/tsDataModelHooksCallRefExpansion';
import {objectVersionNiceString} from '../wfReactInterface/utilities';
import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface';
Expand Down Expand Up @@ -519,6 +523,50 @@ export const CallsTable: FC<{
project
);

// Fetch annotation columns and make them hidden by default
const annotationColumns = useBaseObjectInstances('AnnotationSpec', {
project_id: projectIdFromParts({entity, project}),
filter: {
latest_only: true,
},
});
useEffect(() => {
if (!annotationColumns.result || annotationColumns.result.length === 0) {
return;
}
if (!setColumnVisibilityModel || !columnVisibilityModel) {
return;
}
// Check if we need to update - only update if any annotation columns are missing from the model
const needsUpdate = annotationColumns.result.some(
col => columnVisibilityModel[`feedback.${col.object_id}`] === undefined
);

if (!needsUpdate) {
return;
}

const annotationColumnVisiblityFalse = annotationColumns.result.reduce(
(acc, col) => {
// Only add columns that aren't already in the model
if (columnVisibilityModel[`feedback.${col.object_id}`] === undefined) {
acc[`feedback.${col.object_id}`] = false;
}
return acc;
},
{} as Record<string, boolean>
);

setColumnVisibilityModel({
...columnVisibilityModel,
...annotationColumnVisiblityFalse,
});
}, [
annotationColumns.result,
columnVisibilityModel,
setColumnVisibilityModel,
]);

// Selection Management
const [selectedCalls, setSelectedCalls] = useState<string[]>([]);
const clearSelectedCalls = useCallback(() => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ function buildCallsTableColumns(
},
},
{
field: 'feedback',
headerName: 'Feedback',
field: 'reactions',
headerName: 'Reactions',
width: 150,
sortable: false,
filterable: false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export function isDynamicCallColumn(path: Path): boolean {
}
return (
path.length > 1 &&
['attributes', 'inputs', 'output', 'summary'].includes(path[0])
['attributes', 'inputs', 'output', 'summary', 'feedback'].includes(path[0])
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import {
GridPaginationModel,
GridSortModel,
} from '@mui/x-data-grid-pro';
import {parseRef} from '@wandb/weave/react';
import {makeRefCall} from '@wandb/weave/util/refs';
import {useCallback, useMemo} from 'react';

import {useDeepMemo} from '../../../../../../hookUtils';
Expand Down Expand Up @@ -120,20 +122,66 @@ export const useCallsForQuery = (
callsStats.refetch();
}, [calls, callsStats, costs]);

// query for feeback
const {useFeedbackByTypeAndCallRefs} = useWFHooks();
const feedbackTypeSubstr = 'wandb.annotation.';
const feedbackQuery = useFeedbackByTypeAndCallRefs(
entity,
project,
feedbackTypeSubstr,
callResults.map(call => makeRefCall(entity, project, call.callId))
);

// map of callId to the latest feedback of each feedback_type
const feedbackByType: Record<string, Record<string, any>> | undefined =
useMemo(() => {
return feedbackQuery.result?.reduce(
(acc: Record<string, Record<string, any>>, curr) => {
const callId = parseRef(curr.weave_ref).artifactName;
if (!acc[callId]) {
acc[callId] = {};
}
// Store feedback by feedback_type, newer entries will overwrite older ones
if (curr.feedback_type) {
const feedbackName = curr.feedback_type.replace(
feedbackTypeSubstr,
''
);
acc[callId][feedbackName] = getNestedValue(curr.payload);
}
return acc;
},
{}
);
}, [feedbackQuery.result]);

return useMemo(() => {
if (calls.loading) {
return {
costsLoading: costs.loading,
loading: calls.loading,
result: [],
total: 0,
refetch,
};
}

return {
costsLoading: costs.loading,
loading: calls.loading,
// Return faster calls query results until cost query finishes
result: calls.loading
? []
: costResults.length > 0
? addCostsToCallResults(callResults, costResults)
: callResults,
result: mergeCallData(callResults, costResults, feedbackByType),
total,
refetch,
};
}, [callResults, calls.loading, total, costs.loading, costResults, refetch]);
}, [
callResults,
calls.loading,
total,
costs.loading,
costResults,
refetch,
feedbackByType,
]);
};

export const useFilterSortby = (
Expand Down Expand Up @@ -217,3 +265,46 @@ const convertHighLevelFilterToLowLevelFilter = (
: undefined,
};
};

// Move mergeCallData into the file directly since it's specific to this use case

const mergeCallData = (
baseCallResults: CallSchema[],
costResults: CallSchema[],
feedbackByType?: Record<string, Record<string, any>>
): CallSchema[] => {
// Start with base results
let result = baseCallResults;

// Add feedback if available
if (feedbackByType) {
result = result.map(call => ({
...call,
traceCall: call.traceCall
? {
...call.traceCall,
feedback: feedbackByType[call.callId],
}
: undefined,
}));
}

// Add costs if available
if (costResults.length > 0) {
result = addCostsToCallResults(result, costResults);
}

return result;
};
// Helper to safely get deeply nested values
const getNestedValue = <T>(obj: any, depth: number = 3): T | undefined => {
try {
let result = obj;
for (let i = 0; i < depth; i++) {
result = Object.values(result)[0];
}
return result as T;
} catch {
return undefined;
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,94 @@ const useFeedback = (
return {...result, refetch};
};

const useFeedbackByTypeAndCallRefs = (
entity: string,
project: string,
feedbackType: string,
callRefs: string[],
sortBy?: traceServerTypes.SortBy[]
) => {
const getTsClient = useGetTraceServerClientContext();

const [result, setResult] = useState<
LoadableWithError<traceServerTypes.Feedback[]>
>({
loading: false,
result: null,
error: null,
});
const [doReload, setDoReload] = useState(false);
const refetch = useCallback(() => {
setDoReload(true);
}, [setDoReload]);

const deepCallRefs = useDeepMemo(callRefs);

useEffect(() => {
let mounted = true;
if (doReload) {
setDoReload(false);
}
if (!deepCallRefs || deepCallRefs.length === 0) {
return;
}
setResult({loading: true, result: null, error: null});
getTsClient()
.feedbackQuery({
project_id: projectIdFromParts({
entity,
project,
}),
query: {
$expr: {
$and: [
{
$contains: {
input: {$getField: 'feedback_type'},
substr: {$literal: feedbackType},
},
},
{
$in: [
{$getField: 'weave_ref'},
deepCallRefs.map(ref => ({$literal: ref})),
],
},
],
},
},
sort_by: sortBy ?? [{field: 'created_at', direction: 'desc'}],
})
.then(res => {
if (!mounted) {
return;
}
if ('result' in res) {
setResult({loading: false, result: res.result, error: null});
}
})
.catch(err => {
if (!mounted) {
return;
}
setResult({loading: false, result: null, error: err});
});
return () => {
mounted = false;
};
}, [
deepCallRefs,
getTsClient,
doReload,
sortBy,
feedbackType,
entity,
project,
]);

return {...result, refetch};
};

const useOpVersion = (
// Null value skips
key: OpVersionKey | null
Expand Down Expand Up @@ -1720,6 +1808,7 @@ export const tsWFDataModelHooks: WFDataModelHooksInterface = {
useRefsData,
useApplyMutationsToRef,
useFeedback,
useFeedbackByTypeAndCallRefs,
useFileContent,
useTableRowsQuery,
useTableQueryStats,
Expand Down
Loading
Loading